Skip to content

Commit adb2bf6

Browse files
apaszkeGoogle-ML-Automation
authored andcommitted
[Mosaic TPU] Allow downgrading the IR during serialization for forward compat
This is to uphold the monthly stability promise made by jax.export. PiperOrigin-RevId: 704233290
1 parent a94474d commit adb2bf6

File tree

3 files changed

+96
-13
lines changed

3 files changed

+96
-13
lines changed

jax/_src/tpu_custom_call.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@
6262
help="Allow hlo dialects in Mosaic",
6363
)
6464

65+
66+
# This tracks the latest Mosaic IR version with a monthly delay.
67+
FWD_COMPAT_IR_VERSION = 3
68+
69+
6570
tpu_custom_call_p = core.Primitive("tpu_custom_call")
6671
tpu_custom_call_p.def_impl(
6772
functools.partial(xla.apply_primitive, tpu_custom_call_p))
@@ -407,6 +412,7 @@ def _lower_mosaic_module_to_asm(
407412
backend: str,
408413
device_type: str | None,
409414
kernel_name: str | None,
415+
ir_version: int | None = None,
410416
) -> tuple[ir.Module, tuple[bool, bool, bool, bool]]:
411417
has_communication, has_custom_barrier = tpu.private_has_communication(
412418
module.operation
@@ -438,8 +444,17 @@ def _lower_mosaic_module_to_asm(
438444
module_op = module.operation.clone()
439445
prev_allow_unregistered_dialects = ctx.allow_unregistered_dialects
440446
ctx.allow_unregistered_dialects = True
447+
# TODO(apaszke): Remove once the minimum jaxlib version is at least 0.4.37.
448+
if jax.version._version_as_tuple(jax.lib.__version__) < (0, 4, 37):
449+
target_version = ""
450+
else:
451+
target_version = (
452+
f"target-version={ir_version}" if ir_version is not None else ""
453+
)
441454
try:
442-
pipeline = PassManager.parse("builtin.module(mosaic-serde{serialize=true})")
455+
pipeline = PassManager.parse(
456+
"builtin.module(mosaic-serde{serialize=true " + target_version + "})"
457+
)
443458
pipeline.run(module_op)
444459
finally:
445460
ctx.allow_unregistered_dialects = prev_allow_unregistered_dialects
@@ -506,6 +521,7 @@ def _lower_to_custom_call_config(
506521
serialization_format: int | None,
507522
output_memory_spaces: tuple[MemorySpace | None, ...] | None = None,
508523
kernel_name: str | None = None,
524+
ir_version: int | None = None,
509525
) -> CustomCallBackendConfig:
510526
lowered_module_asm, (
511527
has_communication,
@@ -517,6 +533,7 @@ def _lower_to_custom_call_config(
517533
backend=backend,
518534
device_type=device_type,
519535
kernel_name=kernel_name,
536+
ir_version=ir_version,
520537
)
521538
return _lowered_to_custom_call_config(
522539
lowered_module_asm,
@@ -617,6 +634,7 @@ def lower_module_to_custom_call(
617634
serialization_format=serialization_format,
618635
output_memory_spaces=output_memory_spaces,
619636
kernel_name=kernel_name,
637+
ir_version=FWD_COMPAT_IR_VERSION if ctx.is_forward_compat() else None,
620638
)
621639
return _tpu_custom_call_lowering(
622640
ctx,

jaxlib/mosaic/dialect/tpu/tpu.td

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -780,7 +780,10 @@ def DebugAssertInsertionPass : Pass<"debug-assert-insertion", "::mlir::func::Fun
780780
}
781781

782782
def MosaicSerdePass : Pass<"mosaic-serde", "::mlir::ModuleOp"> {
783-
let options = [Option<"serialize", "serialize", "bool", "", "">];
783+
let options = [
784+
Option<"serialize", "serialize", "bool", "", "">,
785+
Option<"target_version", "target-version", "int", "", ""> // Only used when serialize=true.
786+
];
784787
}
785788

786789
def LogicalToPhysicalDeviceIdPass : Pass<"logical-to-physical-device-id", "::mlir::func::FuncOp"> {

jaxlib/mosaic/dialect/tpu/transforms/serde.cc

Lines changed: 73 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ namespace {
4343

4444
constexpr std::string_view kMangledDialect = "stable_mosaic.";
4545
constexpr StringRef kVersionAttrName = "stable_mosaic.version";
46+
// When this is bumped, we should file a TODO to update the forward-compatible
47+
// version in tpu_custom_call.py in a month!
4648
constexpr int kVersion = 3;
4749

4850
StringRef mangle(StringRef name, std::string* storage) {
@@ -63,7 +65,7 @@ std::optional<StringRef> demangle(StringRef name) {
6365

6466
using rule_type = std::function<LogicalResult(Operation*, int)>;
6567

66-
LogicalResult enqueue_dma_rule(Operation* op, int version) {
68+
LogicalResult enqueue_dma_upgrade(Operation* op, int version) {
6769
// Added AttrSizedOperandSegments and core_id in version 2.
6870
if (version < 2) {
6971
if (op->getNumOperands() == 3) { // Local DMA.
@@ -84,17 +86,21 @@ LogicalResult enqueue_dma_rule(Operation* op, int version) {
8486
return success();
8587
}
8688

87-
LogicalResult semaphore_signal_rule(Operation* op, int version) {
89+
LogicalResult enqueue_dma_downgrade(Operation* op, int version) {
90+
if (version < 2) {
91+
return op->emitError("Downgrade to version ") << version << " unsupported";
92+
}
93+
return success();
94+
}
95+
96+
LogicalResult semaphore_signal_upgrade(Operation* op, int version) {
8897
// Added AttrSizedOperandSegments and core_id in version 2.
8998
if (version < 2) {
9099
if (op->getNumOperands() == 2) { // Local signal.
91100
op->setAttr(OpTrait::AttrSizedOperandSegments<
92101
EnqueueDMAOp>::getOperandSegmentSizeAttr(),
93102
mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 0, 0}));
94103
} else if (op->getNumOperands() == 3) { // Remote signal.
95-
// Hardcoding that one optional value is device_id, not core_id. This
96-
// could misinterpret sem_signals where core_id is specified, but
97-
// device_id isn't.
98104
op->setAttr(OpTrait::AttrSizedOperandSegments<
99105
EnqueueDMAOp>::getOperandSegmentSizeAttr(),
100106
mlir::DenseI32ArrayAttr::get(op->getContext(), {1, 1, 1, 0}));
@@ -105,7 +111,25 @@ LogicalResult semaphore_signal_rule(Operation* op, int version) {
105111
return success();
106112
}
107113

108-
LogicalResult vector_multi_dim_reduce_rule(Operation* op, int version) {
114+
LogicalResult semaphore_signal_downgrade(Operation* op, int version) {
115+
if (version < 2) {
116+
auto operands = op->getAttrOfType<mlir::DenseI32ArrayAttr>(
117+
OpTrait::AttrSizedOperandSegments<
118+
EnqueueDMAOp>::getOperandSegmentSizeAttr());
119+
if (!operands || operands.size() != 4) {
120+
return op->emitError("Missing or invalid AttrSizedOperandSegments");
121+
}
122+
if (operands[3]) {
123+
return op->emitError("Downgrade to version ")
124+
<< version << " impossible: core_id is set";
125+
}
126+
op->removeAttr(OpTrait::AttrSizedOperandSegments<
127+
EnqueueDMAOp>::getOperandSegmentSizeAttr());
128+
}
129+
return success();
130+
}
131+
132+
LogicalResult vector_multi_dim_reduce_upgrade(Operation* op, int version) {
109133
// Changed reductions_dims from ArrayAttr of IntegerAttrs to DenseI64ArrayAttr
110134
// in version 3.
111135
if (version < 3) {
@@ -133,21 +157,49 @@ LogicalResult vector_multi_dim_reduce_rule(Operation* op, int version) {
133157
return success();
134158
}
135159

160+
LogicalResult vector_multi_dim_reduce_downgrade(Operation* op, int version) {
161+
if (version < 3) {
162+
return op->emitError("Downgrade to version ") << version << " unsupported";
163+
}
164+
return success();
165+
}
166+
136167
const llvm::StringMap<rule_type>& upgrade_rules() {
137168
static auto rules = new llvm::StringMap<rule_type>{
138-
{EnqueueDMAOp::getOperationName(), enqueue_dma_rule},
139-
{SemaphoreSignalOp::getOperationName(), semaphore_signal_rule},
169+
{EnqueueDMAOp::getOperationName(), enqueue_dma_upgrade},
170+
{SemaphoreSignalOp::getOperationName(), semaphore_signal_upgrade},
140171
{vector::MultiDimReductionOp::getOperationName(),
141-
vector_multi_dim_reduce_rule}
172+
vector_multi_dim_reduce_upgrade}
142173
};
143174
return *rules;
144175
}
145176

177+
const llvm::StringMap<rule_type>& downgrade_rules() {
178+
static auto rules = new llvm::StringMap<rule_type>{
179+
{EnqueueDMAOp::getOperationName(), enqueue_dma_downgrade},
180+
{SemaphoreSignalOp::getOperationName(), semaphore_signal_downgrade},
181+
{vector::MultiDimReductionOp::getOperationName(),
182+
vector_multi_dim_reduce_downgrade}};
183+
return *rules;
184+
}
185+
146186
struct MosaicSerdePass : public impl::MosaicSerdePassBase<MosaicSerdePass> {
147187
using Base::Base;
148188

149189
void runOnOperation() override {
150190
ModuleOp module = getOperation();
191+
if (!serialize.hasValue()) {
192+
module.emitError("serialize option must be specified");
193+
return signalPassFailure();
194+
}
195+
int serialize_version =
196+
target_version.hasValue() ? target_version : kVersion;
197+
if (serialize && serialize_version > kVersion) {
198+
module.emitError("The highest supported version is ")
199+
<< kVersion << " but requested serialization at version "
200+
<< serialize_version;
201+
return signalPassFailure();
202+
}
151203
if (serialize && !module->getContext()->allowsUnregisteredDialects()) {
152204
module.emitError() << "Cannot serialize within a context that does not "
153205
"allow unregistered dialects.";
@@ -159,7 +211,7 @@ struct MosaicSerdePass : public impl::MosaicSerdePassBase<MosaicSerdePass> {
159211
module->setAttr(
160212
kVersionAttrName,
161213
IntegerAttr::get(IntegerType::get(module->getContext(), 64),
162-
kVersion));
214+
serialize_version));
163215
} else {
164216
IntegerAttr version_attr =
165217
module->getAttrOfType<IntegerAttr>(kVersionAttrName);
@@ -178,7 +230,7 @@ struct MosaicSerdePass : public impl::MosaicSerdePassBase<MosaicSerdePass> {
178230
module->removeAttr(kVersionAttrName);
179231
}
180232
std::string name_storage;
181-
auto result = module.walk([this, &name_storage, version](Operation* op) {
233+
auto result = module.walk([&](Operation* op) {
182234
if (isa<ModuleOp>(op)) { // Don't mangle the ModuleOp itself.
183235
return WalkResult::advance();
184236
}
@@ -210,6 +262,16 @@ struct MosaicSerdePass : public impl::MosaicSerdePassBase<MosaicSerdePass> {
210262
auto new_op = Operation::create(
211263
op->getLoc(), *new_name, op->getResultTypes(), op->getOperands(),
212264
op->getAttrs(), nullptr, op->getSuccessors(), op->getRegions());
265+
// Downgrade the op to the target version, if needed.
266+
if (serialize && kVersion != serialize_version) {
267+
if (const auto rule =
268+
downgrade_rules().find(op->getName().getStringRef());
269+
rule != downgrade_rules().end()) {
270+
if (rule->second(new_op, serialize_version).failed()) {
271+
return WalkResult::interrupt();
272+
}
273+
}
274+
}
213275
op->getBlock()->getOperations().insertAfter(Block::iterator(op), new_op);
214276
op->replaceAllUsesWith(new_op->getResults());
215277
op->erase();

0 commit comments

Comments
 (0)