@@ -13,21 +13,24 @@ See the License for the specific language governing permissions and
1313limitations under the License.
1414==============================================================================*/
1515
16- // We need to keep some extra headers for the code in tpu_passes.h.inc.
16+ # include " jaxlib/mosaic/dialect/tpu/transforms/serde.h "
1717
18- #include < memory> // IWYU pragma: keep
18+ #include < cstdint>
19+ #include < functional>
1920#include < optional>
2021#include < string>
2122#include < string_view>
23+ #include < vector>
2224
25+ #include " mlir/Dialect/Vector/IR/VectorOps.h"
2326#include " mlir/IR/BuiltinAttributes.h"
2427#include " mlir/IR/BuiltinOps.h"
25- #include " mlir/Dialect/Vector/IR/VectorOps.h"
2628#include " mlir/IR/OperationSupport.h"
2729#include " mlir/IR/Value.h"
2830#include " mlir/IR/Visitors.h"
29- #include " mlir/Pass/Pass.h" // IWYU pragma: keep
3031#include " mlir/Support/LLVM.h"
32+ #include " llvm/include/llvm/ADT/StringMap.h"
33+ #include " mlir/include/mlir/IR/Attributes.h"
3134#include " mlir/include/mlir/IR/BuiltinAttributes.h"
3235#include " mlir/include/mlir/IR/OpDefinition.h"
3336#include " mlir/include/mlir/IR/OperationSupport.h"
@@ -36,9 +39,6 @@ limitations under the License.
3639
3740namespace mlir ::tpu {
3841
39- #define GEN_PASS_DEF_MOSAICSERDEPASS
40- #include " jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc"
41-
4242namespace {
4343
4444constexpr std::string_view kMangledDialect = " stable_mosaic." ;
@@ -183,107 +183,101 @@ const llvm::StringMap<rule_type>& downgrade_rules() {
183183 return *rules;
184184}
185185
186- struct MosaicSerdePass : public impl ::MosaicSerdePassBase<MosaicSerdePass> {
187- using Base::Base;
186+ } // namespace
188187
189- void runOnOperation () override {
190- ModuleOp module = getOperation ();
191- if (!serialize.hasValue ()) {
192- module .emitError (" serialize option must be specified" );
193- return signalPassFailure ();
194- }
195- int serialize_version =
196- target_version.hasValue () ? target_version : kVersion ;
197- if (serialize && serialize_version > kVersion ) {
198- module .emitError (" The highest supported version is " )
199- << kVersion << " but requested serialization at version "
200- << serialize_version;
201- return signalPassFailure ();
188+ void MosaicSerdePass::runOnOperation () {
189+ ModuleOp module = getOperation ();
190+ if (!serialize.hasValue ()) {
191+ module .emitError (" serialize option must be specified" );
192+ return signalPassFailure ();
193+ }
194+ int serialize_version = target_version.hasValue () ? target_version : kVersion ;
195+ if (serialize && serialize_version > kVersion ) {
196+ module .emitError (" The highest supported version is " )
197+ << kVersion << " but requested serialization at version "
198+ << serialize_version;
199+ return signalPassFailure ();
200+ }
201+ if (serialize && !module ->getContext ()->allowsUnregisteredDialects ()) {
202+ module .emitError () << " Cannot serialize within a context that does not "
203+ " allow unregistered dialects." ;
204+ signalPassFailure ();
205+ return ;
206+ }
207+ int version = kVersion ;
208+ if (serialize) {
209+ module ->setAttr (kVersionAttrName ,
210+ IntegerAttr::get (IntegerType::get (module ->getContext (), 64 ),
211+ serialize_version));
212+ } else {
213+ IntegerAttr version_attr =
214+ module ->getAttrOfType <IntegerAttr>(kVersionAttrName );
215+ if (!version_attr) {
216+ module ->emitError (" Missing or invalid Mosaic version attribute" );
217+ signalPassFailure ();
218+ return ;
202219 }
203- if (serialize && ! module -> getContext ()-> allowsUnregisteredDialects () ) {
204- module . emitError () << " Cannot serialize within a context that does not "
205- " allow unregistered dialects. " ;
220+ if (version_attr. getInt () > kVersion ) {
221+ module -> emitError (" Unsupported Mosaic version: expected <= " )
222+ << kVersion << " but got " << version_attr. getInt () ;
206223 signalPassFailure ();
207224 return ;
208225 }
209- int version = kVersion ;
226+ version = version_attr.getInt ();
227+ module ->removeAttr (kVersionAttrName );
228+ }
229+ std::string name_storage;
230+ auto result = module .walk ([&](Operation* op) {
231+ if (isa<ModuleOp>(op)) { // Don't mangle the ModuleOp itself.
232+ return WalkResult::advance ();
233+ }
234+ std::optional<OperationName> new_name;
210235 if (serialize) {
211- module ->setAttr (
212- kVersionAttrName ,
213- IntegerAttr::get (IntegerType::get (module ->getContext (), 64 ),
214- serialize_version));
236+ auto new_name_str = mangle (op->getName ().getStringRef (), &name_storage);
237+ new_name = OperationName (new_name_str, op->getContext ());
215238 } else {
216- IntegerAttr version_attr =
217- module ->getAttrOfType <IntegerAttr>(kVersionAttrName );
218- if (!version_attr) {
219- module ->emitError (" Missing or invalid Mosaic version attribute" );
220- signalPassFailure ();
221- return ;
222- }
223- if (version_attr.getInt () > kVersion ) {
224- module ->emitError (" Unsupported Mosaic version: expected <= " )
225- << kVersion << " but got " << version_attr.getInt ();
226- signalPassFailure ();
227- return ;
228- }
229- version = version_attr.getInt ();
230- module ->removeAttr (kVersionAttrName );
231- }
232- std::string name_storage;
233- auto result = module .walk ([&](Operation* op) {
234- if (isa<ModuleOp>(op)) { // Don't mangle the ModuleOp itself.
235- return WalkResult::advance ();
236- }
237- std::optional<OperationName> new_name;
238- if (serialize) {
239- auto new_name_str = mangle (op->getName ().getStringRef (), &name_storage);
240- new_name = OperationName (new_name_str, op->getContext ());
241- } else {
242- if (auto demangled = demangle (op->getName ().getStringRef ())) {
243- auto new_name_str = *demangled;
244- if (auto registered = RegisteredOperationName::lookup (
245- new_name_str, op->getContext ())) {
246- new_name = *registered;
247- } else {
248- new_name = OperationName (new_name_str, op->getContext ());
249- }
239+ if (auto demangled = demangle (op->getName ().getStringRef ())) {
240+ auto new_name_str = *demangled;
241+ if (auto registered = RegisteredOperationName::lookup (
242+ new_name_str, op->getContext ())) {
243+ new_name = *registered;
250244 } else {
251- op->emitError (" Operation not in a serialized form" );
252- return WalkResult::interrupt ();
245+ new_name = OperationName (new_name_str, op->getContext ());
253246 }
254- // Upgrade the op to the current version, if needed.
255- if (const auto rule = upgrade_rules ().find (new_name->getStringRef ());
256- rule != upgrade_rules ().end ()) {
257- if (rule->second (op, version).failed ()) {
258- return WalkResult::interrupt ();
259- }
247+ } else {
248+ op->emitError (" Operation not in a serialized form" );
249+ return WalkResult::interrupt ();
250+ }
251+ // Upgrade the op to the current version, if needed.
252+ if (const auto rule = upgrade_rules ().find (new_name->getStringRef ());
253+ rule != upgrade_rules ().end ()) {
254+ if (rule->second (op, version).failed ()) {
255+ return WalkResult::interrupt ();
260256 }
261257 }
262- auto new_op = Operation::create (
263- op-> getLoc (), *new_name, op-> getResultTypes (), op-> getOperands (),
264- op->getAttrs (), nullptr , op->getSuccessors (), op->getRegions ());
265- // Downgrade the op to the target version, if needed.
266- if (serialize && kVersion != serialize_version) {
267- if (const auto rule =
268- downgrade_rules (). find (op-> getName (). getStringRef ());
269- rule != downgrade_rules ().end ()) {
270- if ( rule-> second (new_op, serialize_version). failed ()) {
271- return WalkResult::interrupt ();
272- }
258+ }
259+ auto new_op = Operation::create (
260+ op->getLoc (), *new_name , op->getResultTypes (), op->getOperands (),
261+ op-> getAttrs (), nullptr , op-> getSuccessors (), op-> getRegions ());
262+ // Downgrade the op to the target version, if needed.
263+ if (serialize && kVersion != serialize_version) {
264+ if ( const auto rule =
265+ downgrade_rules ().find (op-> getName (). getStringRef ());
266+ rule != downgrade_rules (). end ()) {
267+ if (rule-> second (new_op, serialize_version). failed ()) {
268+ return WalkResult::interrupt ();
273269 }
274270 }
275- op->getBlock ()->getOperations ().insertAfter (Block::iterator (op), new_op);
276- op->replaceAllUsesWith (new_op->getResults ());
277- op->erase ();
278- return WalkResult::advance ();
279- });
280- if (result.wasInterrupted ()) {
281- signalPassFailure ();
282- return ;
283271 }
272+ op->getBlock ()->getOperations ().insertAfter (Block::iterator (op), new_op);
273+ op->replaceAllUsesWith (new_op->getResults ());
274+ op->erase ();
275+ return WalkResult::advance ();
276+ });
277+ if (result.wasInterrupted ()) {
278+ signalPassFailure ();
279+ return ;
284280 }
285- };
286-
287- } // namespace
281+ }
288282
289- } // namespace mlir::tpu
283+ } // namespace mlir::tpu
0 commit comments