1- /* Copyright 2025 The TensorFlow Authors. All Rights Reserved.
2- Licensed under the Apache License, Version 2.0 (the "License");
3- you may not use this file except in compliance with the License.
4- You may obtain a copy of the License at
5- http://www.apache.org/licenses/LICENSE-2.0
6- Unless required by applicable law or agreed to in writing, software
7- distributed under the License is distributed on an "AS IS" BASIS,
8- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9- See the License for the specific language governing permissions and
10- limitations under the License.
11- ==============================================================================*/
12-
1+ // ===- AddReflectionMap.cpp - Add a reflection map to a class -------------===//
2+ //
3+ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+ // See https://llvm.org/LICENSE.txt for license information.
5+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+ //
7+ // ===----------------------------------------------------------------------===//
138#include " mlir/Dialect/EmitC/IR/EmitC.h"
149#include " mlir/Dialect/EmitC/Transforms/Passes.h"
1510#include " mlir/Dialect/EmitC/Transforms/Transforms.h"
@@ -35,7 +30,7 @@ class AddReflectionMapPass
3530 Operation *rootOp = getOperation ();
3631
3732 RewritePatternSet patterns (&getContext ());
38- populateAddReflectionMapPatterns (patterns);
33+ populateAddReflectionMapPatterns (patterns, namedAttribute );
3934
4035 walkAndApplyPatterns (rootOp, std::move (patterns));
4136 }
@@ -47,8 +42,8 @@ class AddReflectionMapPass
4742
4843class AddReflectionMapClass : public OpRewritePattern <emitc::ClassOp> {
4944public:
50- AddReflectionMapClass (MLIRContext *context)
51- : OpRewritePattern<emitc::ClassOp>(context) {}
45+ AddReflectionMapClass (MLIRContext *context, StringRef attrName )
46+ : OpRewritePattern<emitc::ClassOp>(context), attributeName(attrName) {}
5247
5348 LogicalResult matchAndRewrite (mlir::emitc::ClassOp classOp,
5449 PatternRewriter &rewriter) const override {
@@ -73,23 +68,23 @@ class AddReflectionMapClass : public OpRewritePattern<emitc::ClassOp> {
7368 rewriter.setInsertionPointToStart (funcBody);
7469
7570 // Collect all field names
76- SmallVector <std::string> fieldNames;
71+ std::vector <std::pair<std:: string, std::string> > fieldNames;
7772 classOp.walk ([&](mlir::emitc::FieldOp fieldOp) {
7873 if (mlir::Attribute attrsAttr =
7974 fieldOp->getAttrDictionary ().get (" attrs" )) {
8075 if (DictionaryAttr innerDictAttr =
8176 dyn_cast<mlir::DictionaryAttr>(attrsAttr)) {
82- auto indexPathAttr =
83- innerDictAttr.getNamed (" tf_saved_model.index_path" );
77+ auto indexPathAttr = innerDictAttr.getNamed (attributeName);
8478 ArrayAttr arrayAttr =
8579 dyn_cast<mlir::ArrayAttr>(indexPathAttr->getValue ());
8680 if (!arrayAttr.empty ()) {
8781 StringAttr stringAttr = dyn_cast<mlir::StringAttr>(arrayAttr[0 ]);
8882 std::string indexPath = stringAttr.getValue ().str ();
89- fieldNames.push_back (indexPath);
83+ fieldNames.emplace_back (indexPath, fieldOp. getName (). str () );
9084 }
9185 if (arrayAttr.size () > 1 ) {
92- fieldOp.emitError () << " tf_saved_model.index_path attribute must "
86+ fieldOp.emitError () << attributeName
87+ << " attribute must "
9388 " contain at most one value, but found "
9489 << arrayAttr.size () << " values." ;
9590 return ;
@@ -98,64 +93,54 @@ class AddReflectionMapClass : public OpRewritePattern<emitc::ClassOp> {
9893 }
9994 });
10095
101- std::string mapInitializer = " {{ " ;
96+ std::string mapInitializer = " { " ;
10297 for (size_t i = 0 ; i < fieldNames.size (); ++i) {
103- mapInitializer += " \" " + fieldNames[i] + " \" , " +
104- " reinterpret_cast<char*>(&" + fieldNames[i] + " )" ,
105- mapInitializer += " }" ;
98+ mapInitializer += " { \" " + fieldNames[i].first + " \" , " +
99+ " reinterpret_cast<char*>(&" + fieldNames[i].second +
100+ " )" ,
101+ mapInitializer += " }" ;
106102 if (i < fieldNames.size () - 1 )
107- mapInitializer += " , { " ;
103+ mapInitializer += " , " ;
108104 }
109- mapInitializer += " }" ;
105+ mapInitializer += " }" ;
110106
111- auto iteratorType = mlir::emitc::OpaqueType::get (
107+ emitc::OpaqueType iteratorType = mlir::emitc::OpaqueType::get (
112108 context, " std::map<std::string, char*>::const_iterator" );
113- auto boolType = rewriter.getI1Type ();
114- // 5. Create the constant map
115- auto bufferMap = rewriter.create <emitc::ConstantOp>(
109+
110+ emitc::ConstantOp bufferMap = rewriter.create <emitc::ConstantOp>(
116111 classOp.getLoc (), mapType,
117112 emitc::OpaqueAttr::get (context, mapInitializer));
118113
119- // 6. Get the function argument
120114 mlir::Value nameArg = getBufferFunc.getArgument (0 );
121-
122- // 7. Create the find call
123- auto it = rewriter.create <emitc::CallOpaqueOp>(
115+ emitc::CallOpaqueOp it = rewriter.create <emitc::CallOpaqueOp>(
124116 classOp.getLoc (), iteratorType, rewriter.getStringAttr (" find" ),
125117 mlir::ValueRange{bufferMap.getResult (), nameArg});
126-
127- // 8. Create the end call
128- auto endIt = rewriter.create <emitc::CallOpaqueOp>(
118+ emitc::CallOpaqueOp endIt = rewriter.create <emitc::CallOpaqueOp>(
129119 classOp.getLoc (), iteratorType, rewriter.getStringAttr (" end" ),
130120 bufferMap.getResult ());
131-
132- // 9. Create the operator== call
133- auto isEnd = rewriter.create <emitc::CallOpaqueOp>(
134- classOp.getLoc (), boolType,
121+ emitc::CallOpaqueOp isEnd = rewriter.create <emitc::CallOpaqueOp>(
122+ classOp.getLoc (), rewriter.getI1Type (),
135123 " operator==" , mlir::ValueRange{it.getResult (0 ), endIt.getResult (0 )});
136-
137- // 10. Create the nullptr constant
138- auto nullPtr = rewriter.create <emitc::ConstantOp>(
124+ emitc::ConstantOp nullPtr = rewriter.create <emitc::ConstantOp>(
139125 classOp.getLoc (), charPtrType,
140126 emitc::OpaqueAttr::get (context, " nullptr" ));
141-
142- // 11. Create the second call
143- auto second = rewriter.create <emitc::CallOpaqueOp>(
127+ emitc::CallOpaqueOp second = rewriter.create <emitc::CallOpaqueOp>(
144128 classOp.getLoc (), charPtrType, " second" , it.getResult (0 ));
145129
146- // 12. Create the conditional
147- auto result = rewriter.create <emitc::ConditionalOp>(
130+ emitc::ConditionalOp result = rewriter.create <emitc::ConditionalOp>(
148131 classOp.getLoc (), charPtrType, isEnd.getResult (0 ), nullPtr.getResult (),
149132 second.getResult (0 ));
150133
151- // 13. Create return
152134 rewriter.create <emitc::ReturnOp>(classOp.getLoc (), result.getResult ());
153135
154136 return success ();
155137 }
138+
139+ private:
140+ StringRef attributeName;
156141};
157142
158- void mlir::emitc::populateAddReflectionMapPatterns (
159- RewritePatternSet &patterns ) {
160- patterns.add <AddReflectionMapClass>(patterns.getContext ());
143+ void mlir::emitc::populateAddReflectionMapPatterns (RewritePatternSet &patterns,
144+ StringRef namedAttribute ) {
145+ patterns.add <AddReflectionMapClass>(patterns.getContext (), namedAttribute );
161146}
0 commit comments