Skip to content

Commit afe86ff

Browse files
committed
Specify the pass reqs
1 parent da210bd commit afe86ff

File tree

3 files changed

+66
-22
lines changed

3 files changed

+66
-22
lines changed

mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,33 +53,41 @@ def WrapFuncInClassPass : Pass<"wrap-emitc-func-in-class"> {
5353
let dependentDialects = ["emitc::EmitCDialect"];
5454
}
5555

56-
def AddReflectionMapPass : Pass<"add-reflection-map"> {
56+
def AddReflectionMapPass : Pass<"add-reflection-map", "ModuleOp"> {
5757
let summary =
5858
"Add a reflection map function to EmitC classes for runtime field lookup";
5959
let description = [{
6060
This pass adds a `getBufferForName` function to EmitC classes that enables
6161
runtime lookup of field buffers by their string names.
62-
This enables runtime introspection and dynamic access to class fields by name,
63-
which is useful for interfacing with external systems that need to access
64-
tensors/buffers by their semantic names.
62+
This would require that the class has fields with attributes and a function named `execute`.
63+
The `fieldop` attribute is expected to be a dictionary where:
64+
- The keys are `namedAttribute`.
65+
- The values are arrays containing a single string attribute.
66+
67+
68+
Example:
6569

66-
Example transformation:
6770
```mlir
6871
emitc.class @MyClass {
6972
emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.field_ref = ["another_feature"]}
7073
emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.field_ref = ["some_feature"]}
7174
emitc.func @execute() { ... }
7275
}
73-
```
7476

75-
Becomes:
76-
```mlir
77+
// becomes:
78+
7779
emitc.class @MyClass {
7880
emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.field_ref = ["another_feature"]}
7981
emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.field_ref = ["some_feature"]}
8082
emitc.func @getBufferForName(%name : !emitc.opaque<"std::string_view">) -> !emitc.opaque<"char*"> {
81-
%map = "emitc.constant"(){value = #emitc.opaque<"{"another_feature", reinterpret_cast<char*>(&another_feature)}, {"some_feature", reinterpret_cast<char*>(&some_feature)}">} : () -> !emitc.opaque<"std::map<std::string, char*>">
82-
return %null : !emitc.opaque<"char*">
83+
%0 = "emitc.constant"() <{value = #emitc.opaque<"{ { \22another_feature\22, reinterpret_cast<char*>(&fieldName0) }, { \22some_feature\22, reinterpret_cast<char*>(&fieldName1) } }">}> : () -> !emitc.opaque<"const std::map<std::string, char*>">
84+
%1 = call_opaque "find"(%0, %arg0) : (!emitc.opaque<"const std::map<std::string, char*>">, !emitc.opaque<"std::string_view">) -> !emitc.opaque<"std::map<std::string, char*>::const_iterator">
85+
%2 = call_opaque "end"(%0) : (!emitc.opaque<"const std::map<std::string, char*>">) -> !emitc.opaque<"std::map<std::string, char*>::const_iterator">
86+
%3 = call_opaque "operator=="(%1, %2) : (!emitc.opaque<"std::map<std::string, char*>::const_iterator">, !emitc.opaque<"std::map<std::string, char*>::const_iterator">) -> i1
87+
%4 = "emitc.constant"() <{value = #emitc.opaque<"nullptr">}> : () -> !emitc.opaque<"char">
88+
%5 = call_opaque "second"(%1) : (!emitc.opaque<"std::map<std::string, char*>::const_iterator">) -> !emitc.opaque<"char">
89+
%6 = conditional %3, %4, %5 : !emitc.opaque<"char">
90+
return %6 : !emitc.opaque<"char">
8391
}
8492
emitc.func @execute() { ... }
8593
}

mlir/lib/Dialect/EmitC/Transforms/AddReflectionMap.cpp

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,48 @@ namespace emitc {
2323
#include "mlir/Dialect/EmitC/Transforms/Passes.h.inc"
2424

2525
namespace {
26+
constexpr const char *kMapLibraryHeader = "map";
27+
constexpr const char *kStringLibraryHeader = "string";
2628
class AddReflectionMapPass
2729
: public impl::AddReflectionMapPassBase<AddReflectionMapPass> {
2830
using AddReflectionMapPassBase::AddReflectionMapPassBase;
2931
void runOnOperation() override {
30-
Operation *rootOp = getOperation();
32+
mlir::ModuleOp module = getOperation();
3133

3234
RewritePatternSet patterns(&getContext());
3335
populateAddReflectionMapPatterns(patterns, namedAttribute);
3436

35-
walkAndApplyPatterns(rootOp, std::move(patterns));
37+
walkAndApplyPatterns(module, std::move(patterns));
38+
bool hasMap = false;
39+
bool hasString = false;
40+
for (auto &op : *module.getBody()) {
41+
emitc::IncludeOp includeOp = llvm::dyn_cast<mlir::emitc::IncludeOp>(op);
42+
if (!includeOp)
43+
continue;
44+
if (includeOp.getIsStandardInclude()) {
45+
if (includeOp.getInclude() == kMapLibraryHeader)
46+
hasMap = true;
47+
if (includeOp.getInclude() == kStringLibraryHeader)
48+
hasString = true;
49+
}
50+
}
51+
52+
if (hasMap && hasString)
53+
return;
54+
55+
mlir::OpBuilder builder(module.getBody(), module.getBody()->begin());
56+
if (!hasMap) {
57+
StringAttr includeAttr = builder.getStringAttr(kMapLibraryHeader);
58+
builder.create<mlir::emitc::IncludeOp>(
59+
module.getLoc(), includeAttr,
60+
/*is_standard_include=*/builder.getUnitAttr());
61+
}
62+
if (!hasString) {
63+
StringAttr includeAttr = builder.getStringAttr(kStringLibraryHeader);
64+
builder.create<emitc::IncludeOp>(
65+
module.getLoc(), includeAttr,
66+
/*is_standard_include=*/builder.getUnitAttr());
67+
}
3668
}
3769
};
3870

@@ -50,16 +82,20 @@ class AddReflectionMapClass : public OpRewritePattern<emitc::ClassOp> {
5082
mlir::MLIRContext *context = rewriter.getContext();
5183
emitc::OpaqueType stringViewType =
5284
mlir::emitc::OpaqueType::get(rewriter.getContext(), "std::string_view");
53-
emitc::OpaqueType charPtrType =
85+
emitc::OpaqueType charType =
5486
mlir::emitc::OpaqueType::get(rewriter.getContext(), "char");
5587
emitc::OpaqueType mapType = mlir::emitc::OpaqueType::get(
5688
rewriter.getContext(), "const std::map<std::string, char*>");
5789

5890
FunctionType funcType =
59-
rewriter.getFunctionType({stringViewType}, {charPtrType});
91+
rewriter.getFunctionType({stringViewType}, {charType});
6092
emitc::FuncOp executeFunc =
6193
classOp.lookupSymbol<mlir::emitc::FuncOp>("execute");
62-
rewriter.setInsertionPoint(executeFunc);
94+
if (executeFunc)
95+
rewriter.setInsertionPoint(executeFunc);
96+
else
97+
classOp.emitError() << "ClassOp must contain a function named 'execute' "
98+
"to add reflection map";
6399

64100
emitc::FuncOp getBufferFunc = rewriter.create<mlir::emitc::FuncOp>(
65101
classOp.getLoc(), "getBufferForName", funcType);
@@ -74,9 +110,8 @@ class AddReflectionMapClass : public OpRewritePattern<emitc::ClassOp> {
74110
fieldOp->getAttrDictionary().get("attrs")) {
75111
if (DictionaryAttr innerDictAttr =
76112
dyn_cast<mlir::DictionaryAttr>(attrsAttr)) {
77-
auto indexPathAttr = innerDictAttr.getNamed(attributeName);
78-
ArrayAttr arrayAttr =
79-
dyn_cast<mlir::ArrayAttr>(indexPathAttr->getValue());
113+
ArrayAttr arrayAttr = dyn_cast<mlir::ArrayAttr>(
114+
innerDictAttr.getNamed(attributeName)->getValue());
80115
if (!arrayAttr.empty()) {
81116
StringAttr stringAttr = dyn_cast<mlir::StringAttr>(arrayAttr[0]);
82117
std::string indexPath = stringAttr.getValue().str();
@@ -122,13 +157,12 @@ class AddReflectionMapClass : public OpRewritePattern<emitc::ClassOp> {
122157
classOp.getLoc(), rewriter.getI1Type(),
123158
"operator==", mlir::ValueRange{it.getResult(0), endIt.getResult(0)});
124159
emitc::ConstantOp nullPtr = rewriter.create<emitc::ConstantOp>(
125-
classOp.getLoc(), charPtrType,
126-
emitc::OpaqueAttr::get(context, "nullptr"));
160+
classOp.getLoc(), charType, emitc::OpaqueAttr::get(context, "nullptr"));
127161
emitc::CallOpaqueOp second = rewriter.create<emitc::CallOpaqueOp>(
128-
classOp.getLoc(), charPtrType, "second", it.getResult(0));
162+
classOp.getLoc(), charType, "second", it.getResult(0));
129163

130164
emitc::ConditionalOp result = rewriter.create<emitc::ConditionalOp>(
131-
classOp.getLoc(), charPtrType, isEnd.getResult(0), nullPtr.getResult(),
165+
classOp.getLoc(), charType, isEnd.getResult(0), nullPtr.getResult(),
132166
second.getResult(0));
133167

134168
rewriter.create<emitc::ReturnOp>(classOp.getLoc(), result.getResult());

mlir/test/Dialect/EmitC/add_reflection_map.mlir

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ emitc.class @mainClass {
2121
}
2222

2323
// CHECK: module {
24+
// CHECK-NEXT: emitc.include <"map">
25+
// CHECK-NEXT: emitc.include <"string">
2426
// CHECK-NEXT: emitc.class @mainClass {
2527
// CHECK-NEXT: emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.field_ref = ["another_feature"]}
2628
// CHECK-NEXT: emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.field_ref = ["some_feature"]}

0 commit comments

Comments
 (0)