#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
- #include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeRange.h"
- #include "mlir/IR/ValueRange.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/GraphWriter.h"
+ #include "llvm/Support/LogicalResult.h"

namespace mlir {
namespace emitc {
@@ -31,12 +30,13 @@ namespace {

struct WrapFuncInClassPass
    : public impl::WrapFuncInClassPassBase<WrapFuncInClassPass> {
+   using WrapFuncInClassPassBase::WrapFuncInClassPassBase;
  void runOnOperation() override {
    Operation *rootOp = getOperation();
    MLIRContext *context = rootOp->getContext();

    RewritePatternSet patterns(context);
-     populateFuncPatterns(patterns);
+     populateFuncPatterns(patterns, namedAttribute);

    if (failed(applyPatternsGreedily(rootOp, std::move(patterns))))
      return signalPassFailure();
@@ -54,16 +54,13 @@ struct WrapFuncInClassPass
using namespace mlir;
using namespace mlir::emitc;

- static bool validOp(Operation &opToClone) {
-   return isa<emitc::ConstantOp>(opToClone) ||
-          isa<emitc::SubscriptOp>(opToClone) || isa<emitc::LoadOp>(opToClone) ||
-          isa<emitc::AddOp>(opToClone) || isa<emitc::AssignOp>(opToClone) ||
-          isa<emitc::ReturnOp>(opToClone);
- }
-
class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
+ private:
+   std::string attributeName;
+
public:
-   using OpRewritePattern<emitc::FuncOp>::OpRewritePattern;
+   WrapFuncInClass(MLIRContext *context, const std::string &attrName)
+       : OpRewritePattern<emitc::FuncOp>(context), attributeName(attrName) {}

  LogicalResult matchAndRewrite(emitc::FuncOp funcOp,
                                PatternRewriter &rewriter) const override {
@@ -79,23 +76,25 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
    rewriter.setInsertionPointToStart(&newClassOp.getBody().front());

    auto argAttrs = funcOp.getArgAttrs();
-
-     for (const auto &[arg, val] : (zip(*argAttrs, funcOp.getArguments()))) {
-       // FIXME: How can we avoid hardcoding this name?
-       // Should we loop through the dictionary and check for each named
-       // attribute if attr.getName().getValue().contains("tf_saved_model")
-       if (auto namedAttr = dyn_cast<mlir::DictionaryAttr>(arg).getNamed(
-               "tf_saved_model.index_path")) {
-         Attribute nv = namedAttr->getValue();
-         StringAttr fieldName =
-             cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]);
-         TypeAttr typeAttr = TypeAttr::get(val.getType());
-         fields.push_back({fieldName, typeAttr});
-
-         rewriter.create<emitc::FieldOp>(funcOp.getLoc(), fieldName, typeAttr,
-                                         /*attributes*/ arg);
-       } else
-         funcOp->emitOpError("Only Covers TF models");
+     if (argAttrs) {
+       for (const auto &[arg, val] :
+            llvm::zip(*argAttrs, funcOp.getArguments())) {
+         if (auto namedAttr =
+                 dyn_cast<mlir::DictionaryAttr>(arg).getNamed(attributeName)) {
+           Attribute nv = namedAttr->getValue();
+           StringAttr fieldName =
+               cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]);
+           TypeAttr typeAttr = TypeAttr::get(val.getType());
+           fields.push_back({fieldName, typeAttr});
+
+           rewriter.create<emitc::FieldOp>(funcOp.getLoc(), fieldName, typeAttr,
+                                           /*attributes*/ arg);
+         }
+       }
+     } else {
+       funcOp->emitOpError("arguments should have attributes so we can "
+                           "initialize class fields.");
+       return failure();
    }

    rewriter.setInsertionPointToEnd(&newClassOp.getBody().front());
@@ -107,19 +106,20 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
    FuncOp newFuncOp = rewriter.create<emitc::FuncOp>(
        loc, rewriter.getStringAttr("execute"), funcType);

-     rewriter.setInsertionPointToStart(newFuncOp.addEntryBlock());
+     rewriter.createBlock(&newFuncOp.getBody());
+     newFuncOp.getBody().takeBody(funcOp.getBody());

+     rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
    std::vector<Value> newArguments;
    for (auto [fieldName, attr] : fields) {
      auto arg =
          rewriter.create<emitc::GetFieldOp>(loc, attr.getValue(), fieldName);
      newArguments.push_back(arg);
    }

-     IRMapping mapper;
    for (auto [oldArg, newArg] :
-          llvm::zip(funcOp.getArguments(), newArguments)) {
-       mapper.map(oldArg, newArg);
+          llvm::zip(newFuncOp.getArguments(), newArguments)) {
+       rewriter.replaceAllUsesWith(oldArg, newArg);
    }

    while (!newFuncOp.getArguments().empty()) {
@@ -128,23 +128,12 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
      }
    }

-     // TODO: The mapper is easier to use but cloning is more expensive than
-     // moving the body. Working on changing this portion to move the body
-     // instead
-     auto body = llvm::make_early_inc_range(funcOp.getBody().front());
-     for (Operation &opToClone : body) {
-       if (validOp(opToClone)) {
-         rewriter.clone(opToClone, mapper);
-       } else {
-         opToClone.emitOpError("Unsupported operation found");
-       }
-     }
-
    rewriter.replaceOp(funcOp, newClassOp);
    return funcOp->use_empty() ? success() : failure();
  }
};

- void mlir::emitc::populateFuncPatterns(RewritePatternSet &patterns) {
-   patterns.add<WrapFuncInClass>(patterns.getContext());
+ void mlir::emitc::populateFuncPatterns(RewritePatternSet &patterns,
+                                        const std::string &namedAttribute) {
+   patterns.add<WrapFuncInClass>(patterns.getContext(), namedAttribute);
}
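
For reference, a minimal sketch (not part of this patch) of how downstream code might drive the new two-argument populateFuncPatterns entry point directly, outside the pass shown above. The includes and the calls to populateFuncPatterns and applyPatternsGreedily are taken from the diff; the helper name wrapFuncsIntoClasses and the attribute string "my.index_path" are hypothetical placeholders.

// Sketch only: exercise the populateFuncPatterns(patterns, namedAttribute)
// overload introduced by this change. "my.index_path" stands in for the
// argument attribute whose value names each generated emitc.field.
#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/LogicalResult.h"

static llvm::LogicalResult wrapFuncsIntoClasses(mlir::Operation *rootOp) {
  mlir::RewritePatternSet patterns(rootOp->getContext());
  // emitc.func arguments carrying "my.index_path" become class fields.
  mlir::emitc::populateFuncPatterns(patterns, "my.index_path");
  return mlir::applyPatternsGreedily(rootOp, std::move(patterns));
}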