Skip to content

Commit b8ecc0c

Browse files
committed
Return the whole map
1 parent a165ba4 commit b8ecc0c

File tree

3 files changed

+33
-53
lines changed

3 files changed

+33
-53
lines changed

mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,15 +79,9 @@ def AddReflectionMapPass : Pass<"add-reflection-map", "ModuleOp"> {
7979
emitc.class @MyClass {
8080
emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.field_ref = ["another_feature"]}
8181
emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.field_ref = ["some_feature"]}
82-
emitc.func @getBufferForName(%name : !emitc.opaque<"std::string_view">) -> !emitc.opaque<"char*"> {
82+
emitc.func @getFeatures() {
8383
%0 = "emitc.constant"() <{value = #emitc.opaque<"{ { \22another_feature\22, reinterpret_cast<char*>(&fieldName0) }, { \22some_feature\22, reinterpret_cast<char*>(&fieldName1) } }">}> : () -> !emitc.opaque<"const std::map<std::string, char*>">
84-
%1 = call_opaque "find"(%0, %arg0) : (!emitc.opaque<"const std::map<std::string, char*>">, !emitc.opaque<"std::string_view">) -> !emitc.opaque<"std::map<std::string, char*>::const_iterator">
85-
%2 = call_opaque "end"(%0) : (!emitc.opaque<"const std::map<std::string, char*>">) -> !emitc.opaque<"std::map<std::string, char*>::const_iterator">
86-
%3 = call_opaque "operator=="(%1, %2) : (!emitc.opaque<"std::map<std::string, char*>::const_iterator">, !emitc.opaque<"std::map<std::string, char*>::const_iterator">) -> i1
87-
%4 = "emitc.constant"() <{value = #emitc.opaque<"nullptr">}> : () -> !emitc.opaque<"char">
88-
%5 = call_opaque "second"(%1) : (!emitc.opaque<"std::map<std::string, char*>::const_iterator">) -> !emitc.opaque<"char">
89-
%6 = conditional %3, %4, %5 : !emitc.opaque<"char">
90-
return %6 : !emitc.opaque<"char">
84+
return %0 : !emitc.opaque<"const std::map<std::string, char*>">
9185
}
9286
emitc.func @execute() { ... }
9387
}

mlir/include/mlir/Dialect/EmitC/Transforms/Transforms.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ void populateExpressionPatterns(RewritePatternSet &patterns);
3131
/// Populates 'patterns' with func-related patterns.
3232
void populateFuncPatterns(RewritePatternSet &patterns);
3333

34+
/// Populates `patterns` with patterns to add reflection map for EmitC classes.
35+
void populateAddReflectionMapPatterns(RewritePatternSet &patterns,
36+
StringRef namedAttribute);
37+
3438
} // namespace emitc
3539
} // namespace mlir
3640

mlir/lib/Dialect/EmitC/Transforms/AddReflectionMap.cpp

Lines changed: 27 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -80,29 +80,11 @@ class AddReflectionMapClass : public OpRewritePattern<emitc::ClassOp> {
8080

8181
LogicalResult matchAndRewrite(mlir::emitc::ClassOp classOp,
8282
PatternRewriter &rewriter) const override {
83-
mlir::MLIRContext *context = rewriter.getContext();
84-
emitc::OpaqueType stringViewType =
85-
mlir::emitc::OpaqueType::get(rewriter.getContext(), "std::string_view");
86-
emitc::OpaqueType charType =
87-
mlir::emitc::OpaqueType::get(rewriter.getContext(), "char");
88-
emitc::OpaqueType mapType = mlir::emitc::OpaqueType::get(
89-
rewriter.getContext(), "const std::map<std::string, char*>");
90-
91-
FunctionType funcType =
92-
rewriter.getFunctionType({stringViewType}, {charType});
93-
emitc::FuncOp executeFunc =
94-
classOp.lookupSymbol<mlir::emitc::FuncOp>("execute");
95-
if (executeFunc)
96-
rewriter.setInsertionPoint(executeFunc);
97-
else
98-
classOp.emitError() << "ClassOp must contain a function named 'execute' "
99-
"to add reflection map";
83+
MLIRContext *context = rewriter.getContext();
10084

101-
emitc::FuncOp getBufferFunc = rewriter.create<mlir::emitc::FuncOp>(
102-
classOp.getLoc(), "getBufferForName", funcType);
103-
104-
Block *funcBody = getBufferFunc.addEntryBlock();
105-
rewriter.setInsertionPointToStart(funcBody);
85+
// Define the opaque types
86+
emitc::OpaqueType mapType = mlir::emitc::OpaqueType::get(
87+
context, "const std::map<std::string, char*>");
10688

10789
// Collect all field names
10890
std::vector<std::pair<std::string, std::string>> fieldNames;
@@ -138,35 +120,35 @@ class AddReflectionMapClass : public OpRewritePattern<emitc::ClassOp> {
138120
if (i < fieldNames.size() - 1)
139121
mapInitializer += ", ";
140122
}
141-
mapInitializer += " }}";
123+
mapInitializer += " }";
142124

143-
emitc::OpaqueType iteratorType = mlir::emitc::OpaqueType::get(
144-
context, "std::map<std::string, char*>::const_iterator");
125+
emitc::FuncOp executeFunc =
126+
classOp.lookupSymbol<mlir::emitc::FuncOp>("execute");
127+
if (executeFunc)
128+
rewriter.setInsertionPoint(executeFunc);
129+
else
130+
classOp.emitError() << "ClassOp must contain a function named 'execute' "
131+
"to add reflection map";
132+
133+
emitc::OpaqueType returnType = mlir::emitc::OpaqueType::get(
134+
context, "const std::map<std::string, char*>");
135+
136+
// Create the getFeatures function
137+
emitc::FuncOp getFeaturesFunc = rewriter.create<mlir::emitc::FuncOp>(
138+
classOp.getLoc(), "getFeatures",
139+
rewriter.getFunctionType({}, {returnType}));
140+
141+
// Add the body of the getFeatures function
142+
Block *funcBody = getFeaturesFunc.addEntryBlock();
143+
rewriter.setInsertionPointToStart(funcBody);
145144

145+
// Create the constant map
146146
emitc::ConstantOp bufferMap = rewriter.create<emitc::ConstantOp>(
147147
classOp.getLoc(), mapType,
148148
emitc::OpaqueAttr::get(context, mapInitializer));
149149

150-
mlir::Value nameArg = getBufferFunc.getArgument(0);
151-
emitc::CallOpaqueOp it = rewriter.create<emitc::CallOpaqueOp>(
152-
classOp.getLoc(), iteratorType, rewriter.getStringAttr("find"),
153-
mlir::ValueRange{bufferMap.getResult(), nameArg});
154-
emitc::CallOpaqueOp endIt = rewriter.create<emitc::CallOpaqueOp>(
155-
classOp.getLoc(), iteratorType, rewriter.getStringAttr("end"),
156-
bufferMap.getResult());
157-
emitc::CallOpaqueOp isEnd = rewriter.create<emitc::CallOpaqueOp>(
158-
classOp.getLoc(), rewriter.getI1Type(),
159-
"operator==", mlir::ValueRange{it.getResult(0), endIt.getResult(0)});
160-
emitc::ConstantOp nullPtr = rewriter.create<emitc::ConstantOp>(
161-
classOp.getLoc(), charType, emitc::OpaqueAttr::get(context, "nullptr"));
162-
emitc::CallOpaqueOp second = rewriter.create<emitc::CallOpaqueOp>(
163-
classOp.getLoc(), charType, "second", it.getResult(0));
164-
165-
emitc::ConditionalOp result = rewriter.create<emitc::ConditionalOp>(
166-
classOp.getLoc(), charType, isEnd.getResult(0), nullPtr.getResult(),
167-
second.getResult(0));
168-
169-
rewriter.create<emitc::ReturnOp>(classOp.getLoc(), result.getResult());
150+
rewriter.create<mlir::emitc::ReturnOp>(classOp.getLoc(),
151+
bufferMap.getResult());
170152

171153
return success();
172154
}

0 commit comments

Comments
 (0)