
Commit 3176910

Added tests for the wrap-emitc-func-in-class pass
1 parent 5610f24 commit 3176910

5 files changed: 72 additions, 49 deletions


mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td

Lines changed: 3 additions & 0 deletions
@@ -28,6 +28,9 @@ def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> {
     to a new `execute` method within the class.
   }];
   let dependentDialects = ["emitc::EmitCDialect"];
+  let options = [Option<
+      "namedAttribute", "named-attribute", "std::string", "\"\"",
+      "Name of the attribute to look for field names on function arguments">];
 }
 
 #endif // MLIR_DIALECT_EMITC_TRANSFORMS_PASSES
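Judging from the RUN lines in the new tests further down, this option surfaces on the textual pass pipeline as `named-attribute`; a typical invocation (the input file name here is illustrative) would look like:

    mlir-opt --wrap-emitc-func-in-class='named-attribute=tf_saved_model.index_path' input.mlir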

mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h

Lines changed: 2 additions & 1 deletion
@@ -29,7 +29,8 @@ ExpressionOp createExpression(Operation *op, OpBuilder &builder);
 void populateExpressionPatterns(RewritePatternSet &patterns);
 
 /// Populates 'patterns' with func-related patterns.
-void populateFuncPatterns(RewritePatternSet &patterns);
+void populateFuncPatterns(RewritePatternSet &patterns,
+                          const std::string &namedAttribute);
 
 } // namespace emitc
 } // namespace mlir
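As a rough sketch of how the updated entry point might be consumed outside the pass (the helper name and attribute key below are assumptions for illustration, not part of this commit):

    // Hypothetical helper: wrap every emitc.func in the module into a class,
    // using a caller-chosen attribute key for the field names.
    // Requires GreedyPatternRewriteDriver.h and the EmitC Transforms headers.
    void wrapFuncsInClasses(mlir::ModuleOp module) {
      mlir::RewritePatternSet patterns(module.getContext());
      // "my.field_name" is an assumed attribute name for this example.
      mlir::emitc::populateFuncPatterns(patterns, "my.field_name");
      if (failed(mlir::applyPatternsGreedily(module, std::move(patterns))))
        module.emitError("failed to wrap emitc.func ops in classes");
    }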

mlir/lib/Dialect/EmitC/Transforms/WrapFuncInClass.cpp

Lines changed: 35 additions & 46 deletions
@@ -12,14 +12,13 @@
 #include "mlir/Dialect/EmitC/Transforms/Transforms.h"
 #include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
-#include "mlir/IR/IRMapping.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeRange.h"
-#include "mlir/IR/ValueRange.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/GraphWriter.h"
+#include "llvm/Support/LogicalResult.h"
 
 namespace mlir {
 namespace emitc {
@@ -31,12 +30,13 @@ namespace {
 
 struct WrapFuncInClassPass
     : public impl::WrapFuncInClassPassBase<WrapFuncInClassPass> {
+  using WrapFuncInClassPassBase::WrapFuncInClassPassBase;
   void runOnOperation() override {
     Operation *rootOp = getOperation();
     MLIRContext *context = rootOp->getContext();
 
     RewritePatternSet patterns(context);
-    populateFuncPatterns(patterns);
+    populateFuncPatterns(patterns, namedAttribute);
 
     if (failed(applyPatternsGreedily(rootOp, std::move(patterns))))
       return signalPassFailure();
@@ -54,16 +54,13 @@ struct WrapFuncInClassPass
 using namespace mlir;
 using namespace mlir::emitc;
 
-static bool validOp(Operation &opToClone) {
-  return isa<emitc::ConstantOp>(opToClone) ||
-         isa<emitc::SubscriptOp>(opToClone) || isa<emitc::LoadOp>(opToClone) ||
-         isa<emitc::AddOp>(opToClone) || isa<emitc::AssignOp>(opToClone) ||
-         isa<emitc::ReturnOp>(opToClone);
-}
-
 class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
+private:
+  std::string attributeName;
+
 public:
-  using OpRewritePattern<emitc::FuncOp>::OpRewritePattern;
+  WrapFuncInClass(MLIRContext *context, const std::string &attrName)
+      : OpRewritePattern<emitc::FuncOp>(context), attributeName(attrName) {}
 
   LogicalResult matchAndRewrite(emitc::FuncOp funcOp,
                                 PatternRewriter &rewriter) const override {
@@ -79,23 +76,25 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
     rewriter.setInsertionPointToStart(&newClassOp.getBody().front());
 
     auto argAttrs = funcOp.getArgAttrs();
-
-    for (const auto &[arg, val] : (zip(*argAttrs, funcOp.getArguments()))) {
-      // FIXME:How can we avoid hardcoding this name?
-      // Should we loop through the dictionary and check for each named
-      // attribute if attr.getName().getValue().contains("tf_saved_model")
-      if (auto namedAttr = dyn_cast<mlir::DictionaryAttr>(arg).getNamed(
-              "tf_saved_model.index_path")) {
-        Attribute nv = namedAttr->getValue();
-        StringAttr fieldName =
-            cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]);
-        TypeAttr typeAttr = TypeAttr::get(val.getType());
-        fields.push_back({fieldName, typeAttr});
-
-        rewriter.create<emitc::FieldOp>(funcOp.getLoc(), fieldName, typeAttr,
-                                        /* attributes*/ arg);
-      } else
-        funcOp->emitOpError("Only Covers TF models");
+    if (argAttrs) {
+      for (const auto &[arg, val] :
+           llvm::zip(*argAttrs, funcOp.getArguments())) {
+        if (auto namedAttr =
+                dyn_cast<mlir::DictionaryAttr>(arg).getNamed(attributeName)) {
+          Attribute nv = namedAttr->getValue();
+          StringAttr fieldName =
+              cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0]);
+          TypeAttr typeAttr = TypeAttr::get(val.getType());
+          fields.push_back({fieldName, typeAttr});
+
+          rewriter.create<emitc::FieldOp>(funcOp.getLoc(), fieldName, typeAttr,
+                                          /* attributes*/ arg);
+        }
+      }
+    } else {
+      funcOp->emitOpError("arguments should have attributes so we can "
+                          "initialize class fields.");
+      return failure();
     }
 
     rewriter.setInsertionPointToEnd(&newClassOp.getBody().front());
@@ -107,19 +106,20 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
     FuncOp newFuncOp = rewriter.create<emitc::FuncOp>(
        loc, rewriter.getStringAttr("execute"), funcType);
 
-    rewriter.setInsertionPointToStart(newFuncOp.addEntryBlock());
+    rewriter.createBlock(&newFuncOp.getBody());
+    newFuncOp.getBody().takeBody(funcOp.getBody());
 
+    rewriter.setInsertionPointToStart(&newFuncOp.getBody().front());
     std::vector<Value> newArguments;
     for (auto [fieldName, attr] : fields) {
       auto arg =
           rewriter.create<emitc::GetFieldOp>(loc, attr.getValue(), fieldName);
       newArguments.push_back(arg);
     }
 
-    IRMapping mapper;
     for (auto [oldArg, newArg] :
-         llvm::zip(funcOp.getArguments(), newArguments)) {
-      mapper.map(oldArg, newArg);
+         llvm::zip(newFuncOp.getArguments(), newArguments)) {
+      rewriter.replaceAllUsesWith(oldArg, newArg);
     }
 
     while (!newFuncOp.getArguments().empty()) {
@@ -128,23 +128,12 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
       }
     }
 
-    // TODO: The mapper is easier to use but cloning is more expensive than
-    // moving the body. Working on changing this portion to move the body
-    // instead
-    auto body = llvm::make_early_inc_range(funcOp.getBody().front());
-    for (Operation &opToClone : body) {
-      if (validOp(opToClone)) {
-        rewriter.clone(opToClone, mapper);
-      } else {
-        opToClone.emitOpError("Unsupported operation found");
-      }
-    }
-
     rewriter.replaceOp(funcOp, newClassOp);
     return funcOp->use_empty() ? success() : failure();
   }
 };
 
-void mlir::emitc::populateFuncPatterns(RewritePatternSet &patterns) {
-  patterns.add<WrapFuncInClass>(patterns.getContext());
+void mlir::emitc::populateFuncPatterns(RewritePatternSet &patterns,
+                                       const std::string &namedAttribute) {
+  patterns.add<WrapFuncInClass>(patterns.getContext(), namedAttribute);
 }
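For reference, the rewritten pattern looks up the configured attribute name in each argument's dictionary attribute and uses the first string of its array value as the field name, so an input function is expected to look roughly like this (attribute key and names follow the test below):

    emitc.func @Model(%arg0: !emitc.array<1xf32> {tf_saved_model.index_path = ["some_feature"]}) {
      emitc.return
    }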

mlir/test/Dialect/EmitC/wrap_emitc_func_in_class.mlir

Lines changed: 24 additions & 2 deletions
@@ -1,7 +1,7 @@
-// RUN: mlir-opt %s --wrap-emitc-func-in-class
+// RUN: mlir-opt --wrap-emitc-func-in-class='named-attribute=tf_saved_model.index_path' %s | FileCheck %s
 
 module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
-  emitc.func @main(%arg0: !emitc.array<1xf32> {tf_saved_model.index_path = ["another_feature"]}, %arg1: !emitc.array<1xf32> {tf_saved_model.index_path = ["some_feature"]}, %arg2: !emitc.array<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "serving_default_another_feature:0,serving_default_some_feature:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
+  emitc.func @Model(%arg0: !emitc.array<1xf32> {tf_saved_model.index_path = ["another_feature"]}, %arg1: !emitc.array<1xf32> {tf_saved_model.index_path = ["some_feature"]}, %arg2: !emitc.array<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "serving_default_another_feature:0,serving_default_some_feature:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} {
     %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
     %1 = subscript %arg1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
     %2 = load %1 : <f32>
@@ -13,3 +13,25 @@ module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted."
     return
   }
 }
+
+// CHECK: module attributes {tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} {
+// CHECK: emitc.class @MyModelClass {
+// CHECK: emitc.field @another_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["another_feature"]}
+// CHECK: emitc.field @some_feature : !emitc.array<1xf32> = {tf_saved_model.index_path = ["some_feature"]}
+// CHECK: emitc.field @output_0 : !emitc.array<1xf32> = {tf_saved_model.index_path = ["output_0"]}
+// CHECK: emitc.func @execute() {
+// CHECK: %{{[0-9]+}} = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
+// CHECK: %{{[0-9]+}} = get_field @another_feature : !emitc.array<1xf32>
+// CHECK: %{{[0-9]+}} = get_field @some_feature : !emitc.array<1xf32>
+// CHECK: %{{[0-9]+}} = get_field @output_0 : !emitc.array<1xf32>
+// CHECK: %{{[0-9]+}} = subscript %{{[0-9]+}}[%{{[0-9]+}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+// CHECK: %{{[0-9]+}} = load %{{[0-9]+}} : <f32>
+// CHECK: %{{[0-9]+}} = subscript %{{[0-9]+}}[%{{[0-9]+}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+// CHECK: %{{[0-9]+}} = load %{{[0-9]+}} : <f32>
+// CHECK: %{{[0-9]+}} = add %{{[0-9]+}}, %{{[0-9]+}} : (f32, f32) -> f32
+// CHECK: %{{[0-9]+}} = subscript %{{[0-9]+}}[%{{[0-9]+}}] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
+// CHECK: assign %{{[0-9]+}} : f32 to %{{[0-9]+}} : <f32>
+// CHECK: return
+// CHECK: }
+// CHECK: }
+// CHECK: }
Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+// RUN: mlir-opt --wrap-emitc-func-in-class='named-attribute=tf_saved_model.index_path' %s 2>&1 | FileCheck %s
+
+emitc.func @foo(%arg0 : i32) {
+  emitc.call_opaque "bar" (%arg0) : (i32) -> ()
+  emitc.return
+}
+
+// CHECK: error: 'emitc.func' op arguments should have attributes so we can initialize class fields.
