Skip to content

Commit da210bd

Browse files
committed
Add an argument
1 parent 58012dd commit da210bd

File tree

3 files changed

+101
-57
lines changed

3 files changed

+101
-57
lines changed

mlir/include/mlir/Dialect/EmitC/Transforms/Passes.td

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,17 +66,17 @@ def AddReflectionMapPass : Pass<"add-reflection-map"> {
6666
Example transformation:
6767
```mlir
6868
emitc.class @MyClass {
69-
emitc.field @fieldName0 : !emitc.array<1xf32> {tf_saved_model.index_path = ["another_feature"]}
70-
emitc.field @fieldName1 : !emitc.array<1xf32> {tf_saved_model.index_path = ["some_feature"]}
69+
emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.field_ref = ["another_feature"]}
70+
emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.field_ref = ["some_feature"]}
7171
emitc.func @execute() { ... }
7272
}
7373
```
7474

7575
Becomes:
7676
```mlir
7777
emitc.class @MyClass {
78-
emitc.field @fieldName0 : !emitc.array<1xf32> {tf_saved_model.index_path = ["another_feature"]}
79-
emitc.field @fieldName1 : !emitc.array<1xf32> {tf_saved_model.index_path = ["some_feature"]}
78+
emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.field_ref = ["another_feature"]}
79+
emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.field_ref = ["some_feature"]}
8080
emitc.func @getBufferForName(%name : !emitc.opaque<"std::string_view">) -> !emitc.opaque<"char*"> {
8181
%map = "emitc.constant"(){value = #emitc.opaque<"{"another_feature", reinterpret_cast<char*>(&another_feature)}, {"some_feature", reinterpret_cast<char*>(&some_feature)}">} : () -> !emitc.opaque<"std::map<std::string, char*>">
8282
return %null : !emitc.opaque<"char*">
@@ -86,6 +86,10 @@ def AddReflectionMapPass : Pass<"add-reflection-map"> {
8686
```
8787
}];
8888
let dependentDialects = ["mlir::emitc::EmitCDialect"];
89+
let options = [Option<"namedAttribute", "named-attribute", "std::string",
90+
/*default=*/"",
91+
"Attribute key used to extract field names from fields "
92+
"dictionary attributes">];
8993
}
9094

9195
#endif // MLIR_DIALECT_EMITC_TRANSFORMS_PASSES
Lines changed: 38 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,10 @@
1-
/* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2-
Licensed under the Apache License, Version 2.0 (the "License");
3-
you may not use this file except in compliance with the License.
4-
You may obtain a copy of the License at
5-
http://www.apache.org/licenses/LICENSE-2.0
6-
Unless required by applicable law or agreed to in writing, software
7-
distributed under the License is distributed on an "AS IS" BASIS,
8-
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9-
See the License for the specific language governing permissions and
10-
limitations under the License.
11-
==============================================================================*/
12-
1+
//===- AddReflectionMap.cpp - Add a reflection map to a class -------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
138
#include "mlir/Dialect/EmitC/IR/EmitC.h"
149
#include "mlir/Dialect/EmitC/Transforms/Passes.h"
1510
#include "mlir/Dialect/EmitC/Transforms/Transforms.h"
@@ -35,7 +30,7 @@ class AddReflectionMapPass
3530
Operation *rootOp = getOperation();
3631

3732
RewritePatternSet patterns(&getContext());
38-
populateAddReflectionMapPatterns(patterns);
33+
populateAddReflectionMapPatterns(patterns, namedAttribute);
3934

4035
walkAndApplyPatterns(rootOp, std::move(patterns));
4136
}
@@ -47,8 +42,8 @@ class AddReflectionMapPass
4742

4843
class AddReflectionMapClass : public OpRewritePattern<emitc::ClassOp> {
4944
public:
50-
AddReflectionMapClass(MLIRContext *context)
51-
: OpRewritePattern<emitc::ClassOp>(context) {}
45+
AddReflectionMapClass(MLIRContext *context, StringRef attrName)
46+
: OpRewritePattern<emitc::ClassOp>(context), attributeName(attrName) {}
5247

5348
LogicalResult matchAndRewrite(mlir::emitc::ClassOp classOp,
5449
PatternRewriter &rewriter) const override {
@@ -73,23 +68,23 @@ class AddReflectionMapClass : public OpRewritePattern<emitc::ClassOp> {
7368
rewriter.setInsertionPointToStart(funcBody);
7469

7570
// Collect all field names
76-
SmallVector<std::string> fieldNames;
71+
std::vector<std::pair<std::string, std::string>> fieldNames;
7772
classOp.walk([&](mlir::emitc::FieldOp fieldOp) {
7873
if (mlir::Attribute attrsAttr =
7974
fieldOp->getAttrDictionary().get("attrs")) {
8075
if (DictionaryAttr innerDictAttr =
8176
dyn_cast<mlir::DictionaryAttr>(attrsAttr)) {
82-
auto indexPathAttr =
83-
innerDictAttr.getNamed("tf_saved_model.index_path");
77+
auto indexPathAttr = innerDictAttr.getNamed(attributeName);
8478
ArrayAttr arrayAttr =
8579
dyn_cast<mlir::ArrayAttr>(indexPathAttr->getValue());
8680
if (!arrayAttr.empty()) {
8781
StringAttr stringAttr = dyn_cast<mlir::StringAttr>(arrayAttr[0]);
8882
std::string indexPath = stringAttr.getValue().str();
89-
fieldNames.push_back(indexPath);
83+
fieldNames.emplace_back(indexPath, fieldOp.getName().str());
9084
}
9185
if (arrayAttr.size() > 1) {
92-
fieldOp.emitError() << "tf_saved_model.index_path attribute must "
86+
fieldOp.emitError() << attributeName
87+
<< " attribute must "
9388
"contain at most one value, but found "
9489
<< arrayAttr.size() << " values.";
9590
return;
@@ -98,64 +93,54 @@ class AddReflectionMapClass : public OpRewritePattern<emitc::ClassOp> {
9893
}
9994
});
10095

101-
std::string mapInitializer = "{{";
96+
std::string mapInitializer = "{ ";
10297
for (size_t i = 0; i < fieldNames.size(); ++i) {
103-
mapInitializer += "\"" + fieldNames[i] + "\", " +
104-
"reinterpret_cast<char*>(&" + fieldNames[i] + ")",
105-
mapInitializer += "}";
98+
mapInitializer += " { \"" + fieldNames[i].first + "\", " +
99+
"reinterpret_cast<char*>(&" + fieldNames[i].second +
100+
")",
101+
mapInitializer += " }";
106102
if (i < fieldNames.size() - 1)
107-
mapInitializer += ", {";
103+
mapInitializer += ", ";
108104
}
109-
mapInitializer += "}";
105+
mapInitializer += " }";
110106

111-
auto iteratorType = mlir::emitc::OpaqueType::get(
107+
emitc::OpaqueType iteratorType = mlir::emitc::OpaqueType::get(
112108
context, "std::map<std::string, char*>::const_iterator");
113-
auto boolType = rewriter.getI1Type();
114-
// 5. Create the constant map
115-
auto bufferMap = rewriter.create<emitc::ConstantOp>(
109+
110+
emitc::ConstantOp bufferMap = rewriter.create<emitc::ConstantOp>(
116111
classOp.getLoc(), mapType,
117112
emitc::OpaqueAttr::get(context, mapInitializer));
118113

119-
// 6. Get the function argument
120114
mlir::Value nameArg = getBufferFunc.getArgument(0);
121-
122-
// 7. Create the find call
123-
auto it = rewriter.create<emitc::CallOpaqueOp>(
115+
emitc::CallOpaqueOp it = rewriter.create<emitc::CallOpaqueOp>(
124116
classOp.getLoc(), iteratorType, rewriter.getStringAttr("find"),
125117
mlir::ValueRange{bufferMap.getResult(), nameArg});
126-
127-
// 8. Create the end call
128-
auto endIt = rewriter.create<emitc::CallOpaqueOp>(
118+
emitc::CallOpaqueOp endIt = rewriter.create<emitc::CallOpaqueOp>(
129119
classOp.getLoc(), iteratorType, rewriter.getStringAttr("end"),
130120
bufferMap.getResult());
131-
132-
// 9. Create the operator== call
133-
auto isEnd = rewriter.create<emitc::CallOpaqueOp>(
134-
classOp.getLoc(), boolType,
121+
emitc::CallOpaqueOp isEnd = rewriter.create<emitc::CallOpaqueOp>(
122+
classOp.getLoc(), rewriter.getI1Type(),
135123
"operator==", mlir::ValueRange{it.getResult(0), endIt.getResult(0)});
136-
137-
// 10. Create the nullptr constant
138-
auto nullPtr = rewriter.create<emitc::ConstantOp>(
124+
emitc::ConstantOp nullPtr = rewriter.create<emitc::ConstantOp>(
139125
classOp.getLoc(), charPtrType,
140126
emitc::OpaqueAttr::get(context, "nullptr"));
141-
142-
// 11. Create the second call
143-
auto second = rewriter.create<emitc::CallOpaqueOp>(
127+
emitc::CallOpaqueOp second = rewriter.create<emitc::CallOpaqueOp>(
144128
classOp.getLoc(), charPtrType, "second", it.getResult(0));
145129

146-
// 12. Create the conditional
147-
auto result = rewriter.create<emitc::ConditionalOp>(
130+
emitc::ConditionalOp result = rewriter.create<emitc::ConditionalOp>(
148131
classOp.getLoc(), charPtrType, isEnd.getResult(0), nullPtr.getResult(),
149132
second.getResult(0));
150133

151-
// 13. Create return
152134
rewriter.create<emitc::ReturnOp>(classOp.getLoc(), result.getResult());
153135

154136
return success();
155137
}
138+
139+
private:
140+
StringRef attributeName;
156141
};
157142

158-
void mlir::emitc::populateAddReflectionMapPatterns(
159-
RewritePatternSet &patterns) {
160-
patterns.add<AddReflectionMapClass>(patterns.getContext());
143+
void mlir::emitc::populateAddReflectionMapPatterns(RewritePatternSet &patterns,
144+
StringRef namedAttribute) {
145+
patterns.add<AddReflectionMapClass>(patterns.getContext(), namedAttribute);
161146
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// RUN: mlir-opt --add-reflection-map="named-attribute=emitc.field_ref" %s | FileCheck %s
2+
3+
emitc.class @mainClass {
4+
emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.field_ref = ["another_feature"]}
5+
emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.field_ref = ["some_feature"]}
6+
emitc.field @fieldName2 : !emitc.array<1xf32> {emitc.field_ref = ["output_0"]}
7+
emitc.func @execute() {
8+
%0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
9+
%1 = get_field @fieldName0 : !emitc.array<1xf32>
10+
%2 = get_field @fieldName1 : !emitc.array<1xf32>
11+
%3 = get_field @fieldName2 : !emitc.array<1xf32>
12+
%4 = subscript %2[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
13+
%5 = load %4 : <f32>
14+
%6 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
15+
%7 = load %6 : <f32>
16+
%8 = add %5, %7 : (f32, f32) -> f32
17+
%9 = subscript %3[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
18+
assign %8 : f32 to %9 : <f32>
19+
return
20+
}
21+
}
22+
23+
// CHECK: module {
24+
// CHECK-NEXT: emitc.class @mainClass {
25+
// CHECK-NEXT: emitc.field @fieldName0 : !emitc.array<1xf32> {emitc.field_ref = ["another_feature"]}
26+
// CHECK-NEXT: emitc.field @fieldName1 : !emitc.array<1xf32> {emitc.field_ref = ["some_feature"]}
27+
// CHECK-NEXT: emitc.field @fieldName2 : !emitc.array<1xf32> {emitc.field_ref = ["output_0"]}
28+
// CHECK-NEXT: emitc.func @getBufferForName(%arg0: !emitc.opaque<"std::string_view">) -> !emitc.opaque<"char"> {
29+
// CHECK-NEXT: %0 = "emitc.constant"() <{value = #emitc.opaque<"{ { \22another_feature\22, reinterpret_cast<char*>(&fieldName0) }, { \22some_feature\22, reinterpret_cast<char*>(&fieldName1) }, { \22output_0\22, reinterpret_cast<char*>(&fieldName2) } }">}> : () -> !emitc.opaque<"const std::map<std::string, char*>">
30+
// CHECK-NEXT: %1 = call_opaque "find"(%0, %arg0) : (!emitc.opaque<"const std::map<std::string, char*>">, !emitc.opaque<"std::string_view">) -> !emitc.opaque<"std::map<std::string, char*>::const_iterator">
31+
// CHECK-NEXT: %2 = call_opaque "end"(%0) : (!emitc.opaque<"const std::map<std::string, char*>">) -> !emitc.opaque<"std::map<std::string, char*>::const_iterator">
32+
// CHECK-NEXT: %3 = call_opaque "operator=="(%1, %2) : (!emitc.opaque<"std::map<std::string, char*>::const_iterator">, !emitc.opaque<"std::map<std::string, char*>::const_iterator">) -> i1
33+
// CHECK-NEXT: %4 = "emitc.constant"() <{value = #emitc.opaque<"nullptr">}> : () -> !emitc.opaque<"char">
34+
// CHECK-NEXT: %5 = call_opaque "second"(%1) : (!emitc.opaque<"std::map<std::string, char*>::const_iterator">) -> !emitc.opaque<"char">
35+
// CHECK-NEXT: %6 = conditional %3, %4, %5 : !emitc.opaque<"char">
36+
// CHECK-NEXT: return %6 : !emitc.opaque<"char">
37+
// CHECK-NEXT: }
38+
// CHECK-NEXT: emitc.func @execute() {
39+
// CHECK-NEXT: %0 = "emitc.constant"() <{value = 0 : index}> : () -> !emitc.size_t
40+
// CHECK-NEXT: %1 = get_field @fieldName0 : !emitc.array<1xf32>
41+
// CHECK-NEXT: %2 = get_field @fieldName1 : !emitc.array<1xf32>
42+
// CHECK-NEXT: %3 = get_field @fieldName2 : !emitc.array<1xf32>
43+
// CHECK-NEXT: %4 = subscript %2[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
44+
// CHECK-NEXT: %5 = load %4 : <f32>
45+
// CHECK-NEXT: %6 = subscript %1[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
46+
// CHECK-NEXT: %7 = load %6 : <f32>
47+
// CHECK-NEXT: %8 = add %5, %7 : (f32, f32) -> f32
48+
// CHECK-NEXT: %9 = subscript %3[%0] : (!emitc.array<1xf32>, !emitc.size_t) -> !emitc.lvalue<f32>
49+
// CHECK-NEXT: assign %8 : f32 to %9 : <f32>
50+
// CHECK-NEXT: return
51+
// CHECK-NEXT: }
52+
// CHECK-NEXT: }
53+
// CHECK-NEXT: }
54+
55+

0 commit comments

Comments
 (0)