Skip to content

Commit f1cb9df

Browse files
committed
Revert "avoid re-initialization"
This reverts commit f2dee0d99bc1fb258b6cef57dc150cb637cc4ab3.
1 parent 81ac87e commit f1cb9df

File tree

2 files changed

+46
-32
lines changed

2 files changed

+46
-32
lines changed

mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,6 @@ void populateExpressionPatterns(RewritePatternSet &patterns);
3131
/// Populates 'patterns' with func-related patterns.
3232
void populateFuncPatterns(RewritePatternSet &patterns);
3333

34-
/// Populates `patterns` with patterns to add reflection map for EmitC classes.
35-
void populateAddReflectionMapPatterns(RewritePatternSet &patterns,
36-
StringRef namedAttribute);
37-
3834
} // namespace emitc
3935
} // namespace mlir
4036

mlir/lib/Dialect/EmitC/Transforms/AddReflectionMap.cpp

Lines changed: 46 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,29 @@ class AddReflectionMapClass : public OpRewritePattern<emitc::ClassOp> {
8080

8181
LogicalResult matchAndRewrite(mlir::emitc::ClassOp classOp,
8282
PatternRewriter &rewriter) const override {
83-
MLIRContext *context = rewriter.getContext();
84-
83+
mlir::MLIRContext *context = rewriter.getContext();
84+
emitc::OpaqueType stringViewType =
85+
mlir::emitc::OpaqueType::get(rewriter.getContext(), "std::string_view");
86+
emitc::OpaqueType charType =
87+
mlir::emitc::OpaqueType::get(rewriter.getContext(), "char");
8588
emitc::OpaqueType mapType = mlir::emitc::OpaqueType::get(
86-
context, "const std::map<std::string, char*>");
89+
rewriter.getContext(), "const std::map<std::string, char*>");
90+
91+
FunctionType funcType =
92+
rewriter.getFunctionType({stringViewType}, {charType});
93+
emitc::FuncOp executeFunc =
94+
classOp.lookupSymbol<mlir::emitc::FuncOp>("execute");
95+
if (executeFunc)
96+
rewriter.setInsertionPoint(executeFunc);
97+
else
98+
classOp.emitError() << "ClassOp must contain a function named 'execute' "
99+
"to add reflection map";
100+
101+
emitc::FuncOp getBufferFunc = rewriter.create<mlir::emitc::FuncOp>(
102+
classOp.getLoc(), "getBufferForName", funcType);
103+
104+
Block *funcBody = getBufferFunc.addEntryBlock();
105+
rewriter.setInsertionPointToStart(funcBody);
87106

88107
// Collect all field names
89108
std::vector<std::pair<std::string, std::string>> fieldNames;
@@ -110,45 +129,44 @@ class AddReflectionMapClass : public OpRewritePattern<emitc::ClassOp> {
110129
}
111130
});
112131

113-
// Construct the map initializer string
114132
std::string mapInitializer = "{ ";
115133
for (size_t i = 0; i < fieldNames.size(); ++i) {
116134
mapInitializer += " { \"" + fieldNames[i].first + "\", " +
117135
"reinterpret_cast<char*>(&" + fieldNames[i].second +
118-
")";
119-
mapInitializer += " }";
136+
")",
137+
mapInitializer += " }";
120138
if (i < fieldNames.size() - 1)
121139
mapInitializer += ", ";
122140
}
123141
mapInitializer += " }";
124142

125-
emitc::OpaqueType returnType = mlir::emitc::OpaqueType::get(
126-
context, "const std::map<std::string, char*>");
127-
128-
emitc::FuncOp executeFunc =
129-
classOp.lookupSymbol<mlir::emitc::FuncOp>("execute");
130-
if (executeFunc)
131-
rewriter.setInsertionPoint(executeFunc);
132-
else
133-
classOp.emitError() << "ClassOp must contain a function named 'execute' "
134-
"to add reflection map";
135-
136-
// Create the getFeatures function
137-
emitc::FuncOp getFeaturesFunc = rewriter.create<mlir::emitc::FuncOp>(
138-
classOp.getLoc(), "getFeatures",
139-
rewriter.getFunctionType({}, {returnType}));
140-
141-
// Add the body of the getFeatures function
142-
Block *funcBody = getFeaturesFunc.addEntryBlock();
143-
rewriter.setInsertionPointToStart(funcBody);
143+
emitc::OpaqueType iteratorType = mlir::emitc::OpaqueType::get(
144+
context, "std::map<std::string, char*>::const_iterator");
144145

145-
// Create the constant map
146146
emitc::ConstantOp bufferMap = rewriter.create<emitc::ConstantOp>(
147147
classOp.getLoc(), mapType,
148148
emitc::OpaqueAttr::get(context, mapInitializer));
149149

150-
rewriter.create<mlir::emitc::ReturnOp>(classOp.getLoc(),
151-
bufferMap.getResult());
150+
mlir::Value nameArg = getBufferFunc.getArgument(0);
151+
emitc::CallOpaqueOp it = rewriter.create<emitc::CallOpaqueOp>(
152+
classOp.getLoc(), iteratorType, rewriter.getStringAttr("find"),
153+
mlir::ValueRange{bufferMap.getResult(), nameArg});
154+
emitc::CallOpaqueOp endIt = rewriter.create<emitc::CallOpaqueOp>(
155+
classOp.getLoc(), iteratorType, rewriter.getStringAttr("end"),
156+
bufferMap.getResult());
157+
emitc::CallOpaqueOp isEnd = rewriter.create<emitc::CallOpaqueOp>(
158+
classOp.getLoc(), rewriter.getI1Type(),
159+
"operator==", mlir::ValueRange{it.getResult(0), endIt.getResult(0)});
160+
emitc::ConstantOp nullPtr = rewriter.create<emitc::ConstantOp>(
161+
classOp.getLoc(), charType, emitc::OpaqueAttr::get(context, "nullptr"));
162+
emitc::CallOpaqueOp second = rewriter.create<emitc::CallOpaqueOp>(
163+
classOp.getLoc(), charType, "second", it.getResult(0));
164+
165+
emitc::ConditionalOp result = rewriter.create<emitc::ConditionalOp>(
166+
classOp.getLoc(), charType, isEnd.getResult(0), nullPtr.getResult(),
167+
second.getResult(0));
168+
169+
rewriter.create<emitc::ReturnOp>(classOp.getLoc(), result.getResult());
152170

153171
return success();
154172
}

0 commit comments

Comments
 (0)