1212#include " mlir/Dialect/EmitC/Transforms/Transforms.h"
1313#include " mlir/IR/Attributes.h"
1414#include " mlir/IR/Builders.h"
15- #include " mlir/IR/IRMapping.h"
1615#include " mlir/IR/PatternMatch.h"
1716#include " mlir/IR/TypeRange.h"
18- #include " mlir/IR/ValueRange.h"
1917#include " mlir/Pass/Pass.h"
2018#include " mlir/Transforms/DialectConversion.h"
2119#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
2220#include " llvm/Support/GraphWriter.h"
21+ #include " llvm/Support/LogicalResult.h"
2322
2423namespace mlir {
2524namespace emitc {
@@ -31,12 +30,13 @@ namespace {
3130
3231struct WrapFuncInClassPass
3332 : public impl::WrapFuncInClassPassBase<WrapFuncInClassPass> {
33+ using WrapFuncInClassPassBase::WrapFuncInClassPassBase;
3434 void runOnOperation () override {
3535 Operation *rootOp = getOperation ();
3636 MLIRContext *context = rootOp->getContext ();
3737
3838 RewritePatternSet patterns (context);
39- populateFuncPatterns (patterns);
39+ populateFuncPatterns (patterns, namedAttribute );
4040
4141 if (failed (applyPatternsGreedily (rootOp, std::move (patterns))))
4242 return signalPassFailure ();
@@ -54,16 +54,13 @@ struct WrapFuncInClassPass
5454using namespace mlir ;
5555using namespace mlir ::emitc;
5656
57- static bool validOp (Operation &opToClone) {
58- return isa<emitc::ConstantOp>(opToClone) ||
59- isa<emitc::SubscriptOp>(opToClone) || isa<emitc::LoadOp>(opToClone) ||
60- isa<emitc::AddOp>(opToClone) || isa<emitc::AssignOp>(opToClone) ||
61- isa<emitc::ReturnOp>(opToClone);
62- }
63-
6457class WrapFuncInClass : public OpRewritePattern <emitc::FuncOp> {
58+ private:
59+ std::string attributeName;
60+
6561public:
66- using OpRewritePattern<emitc::FuncOp>::OpRewritePattern;
62+ WrapFuncInClass (MLIRContext *context, const std::string &attrName)
63+ : OpRewritePattern<emitc::FuncOp>(context), attributeName(attrName) {}
6764
6865 LogicalResult matchAndRewrite (emitc::FuncOp funcOp,
6966 PatternRewriter &rewriter) const override {
@@ -79,23 +76,25 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
7976 rewriter.setInsertionPointToStart (&newClassOp.getBody ().front ());
8077
8178 auto argAttrs = funcOp.getArgAttrs ();
82-
83- for (const auto &[arg, val] : (zip (*argAttrs, funcOp.getArguments ()))) {
84- // FIXME:How can we avoid hardcoding this name?
85- // Should we loop through the dictionary and check for each named
86- // attribute if attr.getName().getValue().contains("tf_saved_model")
87- if (auto namedAttr = dyn_cast<mlir::DictionaryAttr>(arg).getNamed (
88- " tf_saved_model.index_path" )) {
89- Attribute nv = namedAttr->getValue ();
90- StringAttr fieldName =
91- cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0 ]);
92- TypeAttr typeAttr = TypeAttr::get (val.getType ());
93- fields.push_back ({fieldName, typeAttr});
94-
95- rewriter.create <emitc::FieldOp>(funcOp.getLoc (), fieldName, typeAttr,
96- /* attributes*/ arg);
97- } else
98- funcOp->emitOpError (" Only Covers TF models" );
79+ if (argAttrs) {
80+ for (const auto &[arg, val] :
81+ llvm::zip (*argAttrs, funcOp.getArguments ())) {
82+ if (auto namedAttr =
83+ dyn_cast<mlir::DictionaryAttr>(arg).getNamed (attributeName)) {
84+ Attribute nv = namedAttr->getValue ();
85+ StringAttr fieldName =
86+ cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0 ]);
87+ TypeAttr typeAttr = TypeAttr::get (val.getType ());
88+ fields.push_back ({fieldName, typeAttr});
89+
90+ rewriter.create <emitc::FieldOp>(funcOp.getLoc (), fieldName, typeAttr,
91+ /* attributes*/ arg);
92+ }
93+ }
94+ } else {
95+ funcOp->emitOpError (" arguments should have attributes so we can "
96+ " initialize class fields." );
97+ return failure ();
9998 }
10099
101100 rewriter.setInsertionPointToEnd (&newClassOp.getBody ().front ());
@@ -107,19 +106,20 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
107106 FuncOp newFuncOp = rewriter.create <emitc::FuncOp>(
108107 loc, rewriter.getStringAttr (" execute" ), funcType);
109108
110- rewriter.setInsertionPointToStart (newFuncOp.addEntryBlock ());
109+ rewriter.createBlock (&newFuncOp.getBody ());
110+ newFuncOp.getBody ().takeBody (funcOp.getBody ());
111111
112+ rewriter.setInsertionPointToStart (&newFuncOp.getBody ().front ());
112113 std::vector<Value> newArguments;
113114 for (auto [fieldName, attr] : fields) {
114115 auto arg =
115116 rewriter.create <emitc::GetFieldOp>(loc, attr.getValue (), fieldName);
116117 newArguments.push_back (arg);
117118 }
118119
119- IRMapping mapper;
120120 for (auto [oldArg, newArg] :
121- llvm::zip (funcOp .getArguments (), newArguments)) {
122- mapper. map (oldArg, newArg);
121+ llvm::zip (newFuncOp .getArguments (), newArguments)) {
122+ rewriter. replaceAllUsesWith (oldArg, newArg);
123123 }
124124
125125 while (!newFuncOp.getArguments ().empty ()) {
@@ -128,23 +128,12 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
128128 }
129129 }
130130
131- // TODO: The mapper is easier to use but cloning is more expensive than
132- // moving the body. Working on changing this portion to move the body
133- // instead
134- auto body = llvm::make_early_inc_range (funcOp.getBody ().front ());
135- for (Operation &opToClone : body) {
136- if (validOp (opToClone)) {
137- rewriter.clone (opToClone, mapper);
138- } else {
139- opToClone.emitOpError (" Unsupported operation found" );
140- }
141- }
142-
143131 rewriter.replaceOp (funcOp, newClassOp);
144132 return funcOp->use_empty () ? success () : failure ();
145133 }
146134};
147135
148- void mlir::emitc::populateFuncPatterns (RewritePatternSet &patterns) {
149- patterns.add <WrapFuncInClass>(patterns.getContext ());
136+ void mlir::emitc::populateFuncPatterns (RewritePatternSet &patterns,
137+ const std::string &namedAttribute) {
138+ patterns.add <WrapFuncInClass>(patterns.getContext (), namedAttribute);
150139}
0 commit comments