1- // ===- ConvertFuncToClass .cpp - Convert functions to classes -------------===//
1+ // ===- WrapFuncInClass .cpp - Wrap Emitc Funcs in classes -------------===//
22//
33// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44// See https://llvm.org/LICENSE.txt for license information.
55// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66//
77// ===----------------------------------------------------------------------===//
88
9- #include " mlir-c/Rewrite.h"
109#include " mlir/Dialect/EmitC/IR/EmitC.h"
1110#include " mlir/Dialect/EmitC/Transforms/Passes.h"
1211#include " mlir/Dialect/EmitC/Transforms/Transforms.h"
1312#include " mlir/IR/Attributes.h"
1413#include " mlir/IR/Builders.h"
1514#include " mlir/IR/BuiltinAttributes.h"
1615#include " mlir/IR/PatternMatch.h"
17- #include " mlir/IR/TypeRange.h"
18- #include " mlir/IR/Value.h"
19- #include " mlir/Pass/Pass.h"
20- #include " mlir/Transforms/DialectConversion.h"
21- #include " mlir/Transforms/GreedyPatternRewriteDriver.h"
22- #include " llvm/ADT/StringRef.h"
23- #include " llvm/Support/GraphWriter.h"
24- #include " llvm/Support/LogicalResult.h"
25- #include < string>
16+ #include " mlir/Transforms/WalkPatternRewriteDriver.h"
17+
18+ using namespace mlir ;
19+ using namespace emitc ;
2620
2721namespace mlir {
2822namespace emitc {
29-
3023#define GEN_PASS_DEF_WRAPFUNCINCLASSPASS
3124#include " mlir/Dialect/EmitC/Transforms/Passes.h.inc"
3225
3326namespace {
34-
3527struct WrapFuncInClassPass
3628 : public impl::WrapFuncInClassPassBase<WrapFuncInClassPass> {
3729 using WrapFuncInClassPassBase::WrapFuncInClassPassBase;
3830 void runOnOperation () override {
3931 Operation *rootOp = getOperation ();
40- MLIRContext *context = rootOp->getContext ();
4132
42- RewritePatternSet patterns (context );
33+ RewritePatternSet patterns (& getContext () );
4334 populateFuncPatterns (patterns, namedAttribute);
4435
45- if (failed (applyPatternsGreedily (rootOp, std::move (patterns))))
46- return signalPassFailure ();
47- }
48- void getDependentDialects (DialectRegistry ®istry) const override {
49- registry.insert <emitc::EmitCDialect>();
36+ walkAndApplyPatterns (rootOp, std::move (patterns));
5037 }
5138};
5239
5340} // namespace
54-
5541} // namespace emitc
5642} // namespace mlir
5743
58- using namespace mlir ;
59- using namespace mlir ::emitc;
60-
6144class WrapFuncInClass : public OpRewritePattern <emitc::FuncOp> {
62- private:
63- std::string attributeName;
64-
6545public:
66- WrapFuncInClass (MLIRContext *context, const std::string & attrName)
46+ WrapFuncInClass (MLIRContext *context, StringRef attrName)
6747 : OpRewritePattern<emitc::FuncOp>(context), attributeName(attrName) {}
6848
6949 LogicalResult matchAndRewrite (emitc::FuncOp funcOp,
7050 PatternRewriter &rewriter) const override {
71- if (funcOp->getParentOfType <emitc::ClassOp>()) {
72- return failure ();
73- }
51+
7452 auto className = funcOp.getSymNameAttr ().str () + " Class" ;
75- mlir::emitc::ClassOp newClassOp =
76- rewriter.create <emitc::ClassOp>(funcOp.getLoc (), className);
53+ ClassOp newClassOp = rewriter.create <ClassOp>(funcOp.getLoc (), className);
7754
7855 SmallVector<std::pair<StringAttr, TypeAttr>> fields;
7956 rewriter.createBlock (&newClassOp.getBody ());
@@ -84,19 +61,11 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
8461 StringAttr fieldName;
8562 Attribute argAttr = nullptr ;
8663
64+ fieldName = rewriter.getStringAttr (" fieldName" + std::to_string (idx));
8765 if (argAttrs && idx < argAttrs->size ()) {
8866 if (DictionaryAttr dictAttr =
89- dyn_cast<mlir::DictionaryAttr>((*argAttrs)[idx])) {
90- if (auto namedAttr = dictAttr.getNamed (attributeName)) {
91- Attribute nv = namedAttr->getValue ();
92- fieldName = cast<mlir::StringAttr>(cast<mlir::ArrayAttr>(nv)[0 ]);
93- argAttr = (*argAttrs)[idx];
94- }
95- }
96- }
97-
98- if (!fieldName) {
99- fieldName = rewriter.getStringAttr (" fieldName" + std::to_string (idx));
67+ dyn_cast<mlir::DictionaryAttr>((*argAttrs)[idx]))
68+ argAttr = (*argAttrs)[idx];
10069 }
10170
10271 TypeAttr typeAttr = TypeAttr::get (val.getType ());
@@ -106,19 +75,17 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
10675 }
10776
10877 rewriter.setInsertionPointToEnd (&newClassOp.getBody ().front ());
109- MLIRContext *funcContext = funcOp.getContext ();
110- ArrayRef<Type> inputTypes = funcOp.getFunctionType ().getInputs ();
111- ArrayRef<Type> results = funcOp.getFunctionType ().getResults ();
112- FunctionType funcType = FunctionType::get (funcContext, inputTypes, results);
78+ FunctionType funcType = funcOp.getFunctionType ();
11379 Location loc = funcOp.getLoc ();
114- FuncOp newFuncOp = rewriter. create <emitc::FuncOp>(
115- loc, rewriter. getStringAttr (" execute" ), funcType);
80+ FuncOp newFuncOp =
81+ rewriter. create <emitc::FuncOp>( loc, (" execute" ), funcType);
11682
11783 rewriter.createBlock (&newFuncOp.getBody ());
11884 newFuncOp.getBody ().takeBody (funcOp.getBody ());
11985
12086 rewriter.setInsertionPointToStart (&newFuncOp.getBody ().front ());
12187 std::vector<Value> newArguments;
88+ newArguments.reserve (fields.size ());
12289 for (auto &[fieldName, attr] : fields) {
12390 GetFieldOp arg =
12491 rewriter.create <emitc::GetFieldOp>(loc, attr.getValue (), fieldName);
@@ -132,15 +99,18 @@ class WrapFuncInClass : public OpRewritePattern<emitc::FuncOp> {
13299
133100 llvm::BitVector argsToErase (newFuncOp.getNumArguments (), true );
134101 if (failed (newFuncOp.eraseArguments (argsToErase))) {
135- newFuncOp->emitOpError (" Failed to erase all arguments using BitVector. " );
102+ newFuncOp->emitOpError (" failed to erase all arguments using BitVector" );
136103 }
137104
138105 rewriter.replaceOp (funcOp, newClassOp);
139106 return success ();
140107 }
108+
109+ private:
110+ StringRef attributeName;
141111};
142112
143113void mlir::emitc::populateFuncPatterns (RewritePatternSet &patterns,
144- const std::string & namedAttribute) {
114+ StringRef namedAttribute) {
145115 patterns.add <WrapFuncInClass>(patterns.getContext (), namedAttribute);
146116}
0 commit comments