Skip to content

Commit 81ac87e

Browse files
committed
avoid re-initialization
1 parent e3c958a commit 81ac87e

File tree

2 files changed

+32
-46
lines changed

2 files changed

+32
-46
lines changed

mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ void populateExpressionPatterns(RewritePatternSet &patterns);
3131
/// Populates 'patterns' with func-related patterns.
3232
void populateFuncPatterns(RewritePatternSet &patterns);
3333

34+
/// Populates `patterns` with patterns to add reflection map for EmitC classes.
35+
void populateAddReflectionMapPatterns(RewritePatternSet &patterns,
36+
StringRef namedAttribute);
37+
3438
} // namespace emitc
3539
} // namespace mlir
3640

mlir/lib/Dialect/EmitC/Transforms/AddReflectionMap.cpp

Lines changed: 28 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -80,29 +80,10 @@ class AddReflectionMapClass : public OpRewritePattern<emitc::ClassOp> {
8080

8181
LogicalResult matchAndRewrite(mlir::emitc::ClassOp classOp,
8282
PatternRewriter &rewriter) const override {
83-
mlir::MLIRContext *context = rewriter.getContext();
84-
emitc::OpaqueType stringViewType =
85-
mlir::emitc::OpaqueType::get(rewriter.getContext(), "std::string_view");
86-
emitc::OpaqueType charType =
87-
mlir::emitc::OpaqueType::get(rewriter.getContext(), "char");
88-
emitc::OpaqueType mapType = mlir::emitc::OpaqueType::get(
89-
rewriter.getContext(), "const std::map<std::string, char*>");
90-
91-
FunctionType funcType =
92-
rewriter.getFunctionType({stringViewType}, {charType});
93-
emitc::FuncOp executeFunc =
94-
classOp.lookupSymbol<mlir::emitc::FuncOp>("execute");
95-
if (executeFunc)
96-
rewriter.setInsertionPoint(executeFunc);
97-
else
98-
classOp.emitError() << "ClassOp must contain a function named 'execute' "
99-
"to add reflection map";
83+
MLIRContext *context = rewriter.getContext();
10084

101-
emitc::FuncOp getBufferFunc = rewriter.create<mlir::emitc::FuncOp>(
102-
classOp.getLoc(), "getBufferForName", funcType);
103-
104-
Block *funcBody = getBufferFunc.addEntryBlock();
105-
rewriter.setInsertionPointToStart(funcBody);
85+
emitc::OpaqueType mapType = mlir::emitc::OpaqueType::get(
86+
context, "const std::map<std::string, char*>");
10687

10788
// Collect all field names
10889
std::vector<std::pair<std::string, std::string>> fieldNames;
@@ -129,44 +110,45 @@ class AddReflectionMapClass : public OpRewritePattern<emitc::ClassOp> {
129110
}
130111
});
131112

113+
// Construct the map initializer string
132114
std::string mapInitializer = "{ ";
133115
for (size_t i = 0; i < fieldNames.size(); ++i) {
134116
mapInitializer += " { \"" + fieldNames[i].first + "\", " +
135117
"reinterpret_cast<char*>(&" + fieldNames[i].second +
136-
")",
137-
mapInitializer += " }";
118+
")";
119+
mapInitializer += " }";
138120
if (i < fieldNames.size() - 1)
139121
mapInitializer += ", ";
140122
}
141123
mapInitializer += " }";
142124

143-
emitc::OpaqueType iteratorType = mlir::emitc::OpaqueType::get(
144-
context, "std::map<std::string, char*>::const_iterator");
125+
emitc::OpaqueType returnType = mlir::emitc::OpaqueType::get(
126+
context, "const std::map<std::string, char*>");
127+
128+
emitc::FuncOp executeFunc =
129+
classOp.lookupSymbol<mlir::emitc::FuncOp>("execute");
130+
if (executeFunc)
131+
rewriter.setInsertionPoint(executeFunc);
132+
else
133+
classOp.emitError() << "ClassOp must contain a function named 'execute' "
134+
"to add reflection map";
135+
136+
// Create the getFeatures function
137+
emitc::FuncOp getFeaturesFunc = rewriter.create<mlir::emitc::FuncOp>(
138+
classOp.getLoc(), "getFeatures",
139+
rewriter.getFunctionType({}, {returnType}));
140+
141+
// Add the body of the getFeatures function
142+
Block *funcBody = getFeaturesFunc.addEntryBlock();
143+
rewriter.setInsertionPointToStart(funcBody);
145144

145+
// Create the constant map
146146
emitc::ConstantOp bufferMap = rewriter.create<emitc::ConstantOp>(
147147
classOp.getLoc(), mapType,
148148
emitc::OpaqueAttr::get(context, mapInitializer));
149149

150-
mlir::Value nameArg = getBufferFunc.getArgument(0);
151-
emitc::CallOpaqueOp it = rewriter.create<emitc::CallOpaqueOp>(
152-
classOp.getLoc(), iteratorType, rewriter.getStringAttr("find"),
153-
mlir::ValueRange{bufferMap.getResult(), nameArg});
154-
emitc::CallOpaqueOp endIt = rewriter.create<emitc::CallOpaqueOp>(
155-
classOp.getLoc(), iteratorType, rewriter.getStringAttr("end"),
156-
bufferMap.getResult());
157-
emitc::CallOpaqueOp isEnd = rewriter.create<emitc::CallOpaqueOp>(
158-
classOp.getLoc(), rewriter.getI1Type(),
159-
"operator==", mlir::ValueRange{it.getResult(0), endIt.getResult(0)});
160-
emitc::ConstantOp nullPtr = rewriter.create<emitc::ConstantOp>(
161-
classOp.getLoc(), charType, emitc::OpaqueAttr::get(context, "nullptr"));
162-
emitc::CallOpaqueOp second = rewriter.create<emitc::CallOpaqueOp>(
163-
classOp.getLoc(), charType, "second", it.getResult(0));
164-
165-
emitc::ConditionalOp result = rewriter.create<emitc::ConditionalOp>(
166-
classOp.getLoc(), charType, isEnd.getResult(0), nullPtr.getResult(),
167-
second.getResult(0));
168-
169-
rewriter.create<emitc::ReturnOp>(classOp.getLoc(), result.getResult());
150+
rewriter.create<mlir::emitc::ReturnOp>(classOp.getLoc(),
151+
bufferMap.getResult());
170152

171153
return success();
172154
}

0 commit comments

Comments
 (0)