@@ -43,6 +43,8 @@ namespace {
4343
4444constexpr std::string_view kMangledDialect = " stable_mosaic." ;
4545constexpr StringRef kVersionAttrName = " stable_mosaic.version" ;
46+ // When this is bumped, we should file a TODO to update the forward-compatible
47+ // version in tpu_custom_call.py in a month!
4648constexpr int kVersion = 3 ;
4749
4850StringRef mangle (StringRef name, std::string* storage) {
@@ -63,7 +65,7 @@ std::optional<StringRef> demangle(StringRef name) {
6365
6466using rule_type = std::function<LogicalResult(Operation*, int )>;
6567
66- LogicalResult enqueue_dma_rule (Operation* op, int version) {
68+ LogicalResult enqueue_dma_upgrade (Operation* op, int version) {
6769 // Added AttrSizedOperandSegments and core_id in version 2.
6870 if (version < 2 ) {
6971 if (op->getNumOperands () == 3 ) { // Local DMA.
@@ -84,17 +86,21 @@ LogicalResult enqueue_dma_rule(Operation* op, int version) {
8486 return success ();
8587}
8688
87- LogicalResult semaphore_signal_rule (Operation* op, int version) {
89+ LogicalResult enqueue_dma_downgrade (Operation* op, int version) {
90+ if (version < 2 ) {
91+ return op->emitError (" Downgrade to version " ) << version << " unsupported" ;
92+ }
93+ return success ();
94+ }
95+
96+ LogicalResult semaphore_signal_upgrade (Operation* op, int version) {
8897 // Added AttrSizedOperandSegments and core_id in version 2.
8998 if (version < 2 ) {
9099 if (op->getNumOperands () == 2 ) { // Local signal.
91100 op->setAttr (OpTrait::AttrSizedOperandSegments<
92101 EnqueueDMAOp>::getOperandSegmentSizeAttr (),
93102 mlir::DenseI32ArrayAttr::get (op->getContext (), {1 , 1 , 0 , 0 }));
94103 } else if (op->getNumOperands () == 3 ) { // Remote signal.
95- // Hardcoding that one optional value is device_id, not core_id. This
96- // could misinterpret sem_signals where core_id is specified, but
97- // device_id isn't.
98104 op->setAttr (OpTrait::AttrSizedOperandSegments<
99105 EnqueueDMAOp>::getOperandSegmentSizeAttr (),
100106 mlir::DenseI32ArrayAttr::get (op->getContext (), {1 , 1 , 1 , 0 }));
@@ -105,7 +111,25 @@ LogicalResult semaphore_signal_rule(Operation* op, int version) {
105111 return success ();
106112}
107113
108- LogicalResult vector_multi_dim_reduce_rule (Operation* op, int version) {
114+ LogicalResult semaphore_signal_downgrade (Operation* op, int version) {
115+ if (version < 2 ) {
116+ auto operands = op->getAttrOfType <mlir::DenseI32ArrayAttr>(
117+ OpTrait::AttrSizedOperandSegments<
118+ EnqueueDMAOp>::getOperandSegmentSizeAttr ());
119+ if (!operands || operands.size () != 4 ) {
120+ return op->emitError (" Missing or invalid AttrSizedOperandSegments" );
121+ }
122+ if (operands[3 ]) {
123+ return op->emitError (" Downgrade to version " )
124+ << version << " impossible: core_id is set" ;
125+ }
126+ op->removeAttr (OpTrait::AttrSizedOperandSegments<
127+ EnqueueDMAOp>::getOperandSegmentSizeAttr ());
128+ }
129+ return success ();
130+ }
131+
132+ LogicalResult vector_multi_dim_reduce_upgrade (Operation* op, int version) {
109133 // Changed reductions_dims from ArrayAttr of IntegerAttrs to DenseI64ArrayAttr
110134 // in version 3.
111135 if (version < 3 ) {
@@ -133,21 +157,49 @@ LogicalResult vector_multi_dim_reduce_rule(Operation* op, int version) {
133157 return success ();
134158}
135159
160+ LogicalResult vector_multi_dim_reduce_downgrade (Operation* op, int version) {
161+ if (version < 3 ) {
162+ return op->emitError (" Downgrade to version " ) << version << " unsupported" ;
163+ }
164+ return success ();
165+ }
166+
136167const llvm::StringMap<rule_type>& upgrade_rules () {
137168 static auto rules = new llvm::StringMap<rule_type>{
138- {EnqueueDMAOp::getOperationName (), enqueue_dma_rule },
139- {SemaphoreSignalOp::getOperationName (), semaphore_signal_rule },
169+ {EnqueueDMAOp::getOperationName (), enqueue_dma_upgrade },
170+ {SemaphoreSignalOp::getOperationName (), semaphore_signal_upgrade },
140171 {vector::MultiDimReductionOp::getOperationName (),
141- vector_multi_dim_reduce_rule }
172+ vector_multi_dim_reduce_upgrade }
142173 };
143174 return *rules;
144175}
145176
177+ const llvm::StringMap<rule_type>& downgrade_rules () {
178+ static auto rules = new llvm::StringMap<rule_type>{
179+ {EnqueueDMAOp::getOperationName (), enqueue_dma_downgrade},
180+ {SemaphoreSignalOp::getOperationName (), semaphore_signal_downgrade},
181+ {vector::MultiDimReductionOp::getOperationName (),
182+ vector_multi_dim_reduce_downgrade}};
183+ return *rules;
184+ }
185+
146186struct MosaicSerdePass : public impl ::MosaicSerdePassBase<MosaicSerdePass> {
147187 using Base::Base;
148188
149189 void runOnOperation () override {
150190 ModuleOp module = getOperation ();
191+ if (!serialize.hasValue ()) {
192+ module .emitError (" serialize option must be specified" );
193+ return signalPassFailure ();
194+ }
195+ int serialize_version =
196+ target_version.hasValue () ? target_version : kVersion ;
197+ if (serialize && serialize_version > kVersion ) {
198+ module .emitError (" The highest supported version is " )
199+ << kVersion << " but requested serialization at version "
200+ << serialize_version;
201+ return signalPassFailure ();
202+ }
151203 if (serialize && !module ->getContext ()->allowsUnregisteredDialects ()) {
152204 module .emitError () << " Cannot serialize within a context that does not "
153205 " allow unregistered dialects." ;
@@ -159,7 +211,7 @@ struct MosaicSerdePass : public impl::MosaicSerdePassBase<MosaicSerdePass> {
159211 module ->setAttr (
160212 kVersionAttrName ,
161213 IntegerAttr::get (IntegerType::get (module ->getContext (), 64 ),
162- kVersion ));
214+ serialize_version ));
163215 } else {
164216 IntegerAttr version_attr =
165217 module ->getAttrOfType <IntegerAttr>(kVersionAttrName );
@@ -178,7 +230,7 @@ struct MosaicSerdePass : public impl::MosaicSerdePassBase<MosaicSerdePass> {
178230 module ->removeAttr (kVersionAttrName );
179231 }
180232 std::string name_storage;
181- auto result = module .walk ([this , &name_storage, version ](Operation* op) {
233+ auto result = module .walk ([& ](Operation* op) {
182234 if (isa<ModuleOp>(op)) { // Don't mangle the ModuleOp itself.
183235 return WalkResult::advance ();
184236 }
@@ -210,6 +262,16 @@ struct MosaicSerdePass : public impl::MosaicSerdePassBase<MosaicSerdePass> {
210262 auto new_op = Operation::create (
211263 op->getLoc (), *new_name, op->getResultTypes (), op->getOperands (),
212264 op->getAttrs (), nullptr , op->getSuccessors (), op->getRegions ());
265+ // Downgrade the op to the target version, if needed.
266+ if (serialize && kVersion != serialize_version) {
267+ if (const auto rule =
268+ downgrade_rules ().find (op->getName ().getStringRef ());
269+ rule != downgrade_rules ().end ()) {
270+ if (rule->second (new_op, serialize_version).failed ()) {
271+ return WalkResult::interrupt ();
272+ }
273+ }
274+ }
213275 op->getBlock ()->getOperations ().insertAfter (Block::iterator (op), new_op);
214276 op->replaceAllUsesWith (new_op->getResults ());
215277 op->erase ();
0 commit comments