Skip to content

Commit a14e696

Browse files
superbobryGoogle-ML-Automation
authored andcommitted
[mosaic] Migrated the serialization pass from codegen to pass_boilerplate.h
This prepares teh generalization of the serialization pass to handle both Mosaic TPU and GPU. PiperOrigin-RevId: 705628923
1 parent 97459ba commit a14e696

File tree

12 files changed

+199
-115
lines changed

12 files changed

+199
-115
lines changed

jaxlib/BUILD

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,17 @@ cc_library(
170170
],
171171
)
172172

173+
cc_library(
174+
name = "pass_boilerplate",
175+
hdrs = ["pass_boilerplate.h"],
176+
# compatible with libtpu
177+
deps = [
178+
"@llvm-project//mlir:IR",
179+
"@llvm-project//mlir:Pass",
180+
"@llvm-project//mlir:Support",
181+
],
182+
)
183+
173184
cc_library(
174185
name = "handle_pool",
175186
hdrs = ["handle_pool.h"],

jaxlib/mlir/_mlir_libs/tpu_ext.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,7 @@ MlirContext getDefaultContext() {
316316

317317
PYBIND11_MODULE(_tpu_ext, m, py::mod_gil_not_used()) {
318318
mlirRegisterTPUPasses(); // Register all passes on load.
319+
mlirTpuRegisterMosaicSerdePass();
319320

320321
py::class_<MlirTpuApplyVectorLayoutContext>(m, "ApplyVectorLayoutCtx",
321322
py::module_local())

jaxlib/mosaic/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ cc_library(
5656
# compatible with libtpu
5757
deps = [
5858
":tpu_inc_gen",
59+
"//jaxlib:pass_boilerplate",
5960
"@com_google_absl//absl/algorithm:container",
6061
"@com_google_absl//absl/container:flat_hash_set",
6162
"@com_google_absl//absl/hash",

jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ limitations under the License.
4444
#include "jaxlib/mosaic/dialect/tpu/layout.h"
4545
#include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
4646
#include "jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.h"
47+
#include "jaxlib/mosaic/dialect/tpu/transforms/serde.h"
4748
#include "xla/array.h"
4849

4950
// TODO(tlongeri): null pointer checks?
@@ -408,6 +409,10 @@ MlirValue mlirTpuRelayout(MlirTpuInsertionPoint insertion_point, MlirValue val,
408409
}
409410
}
410411

412+
MLIR_CAPI_EXPORTED void mlirTpuRegisterMosaicSerdePass() {
413+
mlir::tpu::registerMosaicSerdePass();
414+
}
415+
411416
#include "mlir/CAPI/Pass.h" // IWYU pragma: keep
412417
#include "mlir/CAPI/Support.h" // IWYU pragma: keep
413418

jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License.
1919
#ifndef JAXLIB_MOSAIC_DIALECT_TPU_INTEGRATIONS_C_TPU_DIALECT_H_
2020
#define JAXLIB_MOSAIC_DIALECT_TPU_INTEGRATIONS_C_TPU_DIALECT_H_
2121

22+
#include "jaxlib/mosaic/dialect/tpu/integrations/c/tpu_dialect.h"
2223
#ifndef __cplusplus
2324
#include <stdbool.h>
2425
#endif
@@ -234,6 +235,10 @@ MLIR_CAPI_EXPORTED MlirValue
234235
mlirTpuRelayout(MlirTpuInsertionPoint insertion_point, MlirValue val,
235236
MlirTpuVectorLayout src, MlirTpuVectorLayout dst,
236237
MlirTpuApplyVectorLayoutContext ctx);
238+
239+
240+
MLIR_CAPI_EXPORTED void mlirTpuRegisterMosaicSerdePass();
241+
237242
#ifdef __cplusplus
238243
}
239244
#endif

jaxlib/mosaic/dialect/tpu/tpu.td

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -786,13 +786,6 @@ def DebugAssertInsertionPass : Pass<"debug-assert-insertion", "::mlir::func::Fun
786786
let constructor = "::mlir::tpu::createDebugAssertInsertionPass()";
787787
}
788788

789-
def MosaicSerdePass : Pass<"mosaic-serde", "::mlir::ModuleOp"> {
790-
let options = [
791-
Option<"serialize", "serialize", "bool", "", "">,
792-
Option<"target_version", "target-version", "int", "", ""> // Only used when serialize=true.
793-
];
794-
}
795-
796789
def LogicalToPhysicalDeviceIdPass : Pass<"logical-to-physical-device-id", "::mlir::func::FuncOp"> {
797790
let dependentDialects = [
798791
"::mlir::func::FuncDialect",

jaxlib/mosaic/dialect/tpu/tpu_dialect.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ limitations under the License.
3131
#include "mlir/include/mlir/Support/LogicalResult.h"
3232
#include "jaxlib/mosaic/dialect/tpu/layout.h"
3333
#include "jaxlib/mosaic/dialect/tpu/tpu_enums.h.inc"
34+
#include "jaxlib/mosaic/dialect/tpu/transforms/serde.h"
3435
#include "xla/layout.h"
3536

3637
namespace mlir::tpu {

jaxlib/mosaic/dialect/tpu/transforms/serde.cc

Lines changed: 89 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,24 @@ See the License for the specific language governing permissions and
1313
limitations under the License.
1414
==============================================================================*/
1515

16-
// We need to keep some extra headers for the code in tpu_passes.h.inc.
16+
#include "jaxlib/mosaic/dialect/tpu/transforms/serde.h"
1717

18-
#include <memory> // IWYU pragma: keep
18+
#include <cstdint>
19+
#include <functional>
1920
#include <optional>
2021
#include <string>
2122
#include <string_view>
23+
#include <vector>
2224

25+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
2326
#include "mlir/IR/BuiltinAttributes.h"
2427
#include "mlir/IR/BuiltinOps.h"
25-
#include "mlir/Dialect/Vector/IR/VectorOps.h"
2628
#include "mlir/IR/OperationSupport.h"
2729
#include "mlir/IR/Value.h"
2830
#include "mlir/IR/Visitors.h"
29-
#include "mlir/Pass/Pass.h" // IWYU pragma: keep
3031
#include "mlir/Support/LLVM.h"
32+
#include "llvm/include/llvm/ADT/StringMap.h"
33+
#include "mlir/include/mlir/IR/Attributes.h"
3134
#include "mlir/include/mlir/IR/BuiltinAttributes.h"
3235
#include "mlir/include/mlir/IR/OpDefinition.h"
3336
#include "mlir/include/mlir/IR/OperationSupport.h"
@@ -36,9 +39,6 @@ limitations under the License.
3639

3740
namespace mlir::tpu {
3841

39-
#define GEN_PASS_DEF_MOSAICSERDEPASS
40-
#include "jaxlib/mosaic/dialect/tpu/tpu_passes.h.inc"
41-
4242
namespace {
4343

4444
constexpr std::string_view kMangledDialect = "stable_mosaic.";
@@ -183,107 +183,101 @@ const llvm::StringMap<rule_type>& downgrade_rules() {
183183
return *rules;
184184
}
185185

186-
struct MosaicSerdePass : public impl::MosaicSerdePassBase<MosaicSerdePass> {
187-
using Base::Base;
186+
} // namespace
188187

189-
void runOnOperation() override {
190-
ModuleOp module = getOperation();
191-
if (!serialize.hasValue()) {
192-
module.emitError("serialize option must be specified");
193-
return signalPassFailure();
194-
}
195-
int serialize_version =
196-
target_version.hasValue() ? target_version : kVersion;
197-
if (serialize && serialize_version > kVersion) {
198-
module.emitError("The highest supported version is ")
199-
<< kVersion << " but requested serialization at version "
200-
<< serialize_version;
201-
return signalPassFailure();
188+
void MosaicSerdePass::runOnOperation() {
189+
ModuleOp module = getOperation();
190+
if (!serialize.hasValue()) {
191+
module.emitError("serialize option must be specified");
192+
return signalPassFailure();
193+
}
194+
int serialize_version = target_version.hasValue() ? target_version : kVersion;
195+
if (serialize && serialize_version > kVersion) {
196+
module.emitError("The highest supported version is ")
197+
<< kVersion << " but requested serialization at version "
198+
<< serialize_version;
199+
return signalPassFailure();
200+
}
201+
if (serialize && !module->getContext()->allowsUnregisteredDialects()) {
202+
module.emitError() << "Cannot serialize within a context that does not "
203+
"allow unregistered dialects.";
204+
signalPassFailure();
205+
return;
206+
}
207+
int version = kVersion;
208+
if (serialize) {
209+
module->setAttr(kVersionAttrName,
210+
IntegerAttr::get(IntegerType::get(module->getContext(), 64),
211+
serialize_version));
212+
} else {
213+
IntegerAttr version_attr =
214+
module->getAttrOfType<IntegerAttr>(kVersionAttrName);
215+
if (!version_attr) {
216+
module->emitError("Missing or invalid Mosaic version attribute");
217+
signalPassFailure();
218+
return;
202219
}
203-
if (serialize && !module->getContext()->allowsUnregisteredDialects()) {
204-
module.emitError() << "Cannot serialize within a context that does not "
205-
"allow unregistered dialects.";
220+
if (version_attr.getInt() > kVersion) {
221+
module->emitError("Unsupported Mosaic version: expected <= ")
222+
<< kVersion << " but got " << version_attr.getInt();
206223
signalPassFailure();
207224
return;
208225
}
209-
int version = kVersion;
226+
version = version_attr.getInt();
227+
module->removeAttr(kVersionAttrName);
228+
}
229+
std::string name_storage;
230+
auto result = module.walk([&](Operation* op) {
231+
if (isa<ModuleOp>(op)) { // Don't mangle the ModuleOp itself.
232+
return WalkResult::advance();
233+
}
234+
std::optional<OperationName> new_name;
210235
if (serialize) {
211-
module->setAttr(
212-
kVersionAttrName,
213-
IntegerAttr::get(IntegerType::get(module->getContext(), 64),
214-
serialize_version));
236+
auto new_name_str = mangle(op->getName().getStringRef(), &name_storage);
237+
new_name = OperationName(new_name_str, op->getContext());
215238
} else {
216-
IntegerAttr version_attr =
217-
module->getAttrOfType<IntegerAttr>(kVersionAttrName);
218-
if (!version_attr) {
219-
module->emitError("Missing or invalid Mosaic version attribute");
220-
signalPassFailure();
221-
return;
222-
}
223-
if (version_attr.getInt() > kVersion) {
224-
module->emitError("Unsupported Mosaic version: expected <= ")
225-
<< kVersion << " but got " << version_attr.getInt();
226-
signalPassFailure();
227-
return;
228-
}
229-
version = version_attr.getInt();
230-
module->removeAttr(kVersionAttrName);
231-
}
232-
std::string name_storage;
233-
auto result = module.walk([&](Operation* op) {
234-
if (isa<ModuleOp>(op)) { // Don't mangle the ModuleOp itself.
235-
return WalkResult::advance();
236-
}
237-
std::optional<OperationName> new_name;
238-
if (serialize) {
239-
auto new_name_str = mangle(op->getName().getStringRef(), &name_storage);
240-
new_name = OperationName(new_name_str, op->getContext());
241-
} else {
242-
if (auto demangled = demangle(op->getName().getStringRef())) {
243-
auto new_name_str = *demangled;
244-
if (auto registered = RegisteredOperationName::lookup(
245-
new_name_str, op->getContext())) {
246-
new_name = *registered;
247-
} else {
248-
new_name = OperationName(new_name_str, op->getContext());
249-
}
239+
if (auto demangled = demangle(op->getName().getStringRef())) {
240+
auto new_name_str = *demangled;
241+
if (auto registered = RegisteredOperationName::lookup(
242+
new_name_str, op->getContext())) {
243+
new_name = *registered;
250244
} else {
251-
op->emitError("Operation not in a serialized form");
252-
return WalkResult::interrupt();
245+
new_name = OperationName(new_name_str, op->getContext());
253246
}
254-
// Upgrade the op to the current version, if needed.
255-
if (const auto rule = upgrade_rules().find(new_name->getStringRef());
256-
rule != upgrade_rules().end()) {
257-
if (rule->second(op, version).failed()) {
258-
return WalkResult::interrupt();
259-
}
247+
} else {
248+
op->emitError("Operation not in a serialized form");
249+
return WalkResult::interrupt();
250+
}
251+
// Upgrade the op to the current version, if needed.
252+
if (const auto rule = upgrade_rules().find(new_name->getStringRef());
253+
rule != upgrade_rules().end()) {
254+
if (rule->second(op, version).failed()) {
255+
return WalkResult::interrupt();
260256
}
261257
}
262-
auto new_op = Operation::create(
263-
op->getLoc(), *new_name, op->getResultTypes(), op->getOperands(),
264-
op->getAttrs(), nullptr, op->getSuccessors(), op->getRegions());
265-
// Downgrade the op to the target version, if needed.
266-
if (serialize && kVersion != serialize_version) {
267-
if (const auto rule =
268-
downgrade_rules().find(op->getName().getStringRef());
269-
rule != downgrade_rules().end()) {
270-
if (rule->second(new_op, serialize_version).failed()) {
271-
return WalkResult::interrupt();
272-
}
258+
}
259+
auto new_op = Operation::create(
260+
op->getLoc(), *new_name, op->getResultTypes(), op->getOperands(),
261+
op->getAttrs(), nullptr, op->getSuccessors(), op->getRegions());
262+
// Downgrade the op to the target version, if needed.
263+
if (serialize && kVersion != serialize_version) {
264+
if (const auto rule =
265+
downgrade_rules().find(op->getName().getStringRef());
266+
rule != downgrade_rules().end()) {
267+
if (rule->second(new_op, serialize_version).failed()) {
268+
return WalkResult::interrupt();
273269
}
274270
}
275-
op->getBlock()->getOperations().insertAfter(Block::iterator(op), new_op);
276-
op->replaceAllUsesWith(new_op->getResults());
277-
op->erase();
278-
return WalkResult::advance();
279-
});
280-
if (result.wasInterrupted()) {
281-
signalPassFailure();
282-
return;
283271
}
272+
op->getBlock()->getOperations().insertAfter(Block::iterator(op), new_op);
273+
op->replaceAllUsesWith(new_op->getResults());
274+
op->erase();
275+
return WalkResult::advance();
276+
});
277+
if (result.wasInterrupted()) {
278+
signalPassFailure();
279+
return;
284280
}
285-
};
286-
287-
} // namespace
281+
}
288282

289-
} // namespace mlir::tpu
283+
} // namespace mlir::tpu
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
#ifndef THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_SERDE_H_
2+
#define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_SERDE_H_
3+
4+
#include <memory>
5+
#include <utility>
6+
7+
#include "llvm/include/llvm/ADT/StringRef.h"
8+
#include "llvm/include/llvm/Support/CommandLine.h"
9+
#include "mlir/include/mlir/Interfaces/DataLayoutInterfaces.h"
10+
#include "mlir/include/mlir/Pass/Pass.h"
11+
#include "mlir/include/mlir/Pass/PassRegistry.h"
12+
#include "jaxlib/pass_boilerplate.h"
13+
14+
namespace mlir::tpu {
15+
16+
struct MosaicSerdePassOptions {
17+
bool serialize;
18+
int target_version;
19+
};
20+
21+
struct MosaicSerdePass : public jaxlib::mlir::Pass<MosaicSerdePass, ModuleOp> {
22+
using jaxlib::mlir::Pass<MosaicSerdePass, ModuleOp>::Pass;
23+
24+
static constexpr llvm::StringLiteral kArgumentName = "mosaic-serde";
25+
static constexpr llvm::StringLiteral kPassName = "MosaicSerdePass";
26+
27+
MosaicSerdePass() = default;
28+
29+
explicit MosaicSerdePass(MosaicSerdePassOptions options) {
30+
serialize = options.serialize;
31+
target_version = options.target_version;
32+
}
33+
34+
MosaicSerdePass(const MosaicSerdePass &other) {
35+
serialize = other.serialize;
36+
target_version = other.target_version;
37+
}
38+
39+
MosaicSerdePass &operator=(const MosaicSerdePass &other) {
40+
serialize = other.serialize;
41+
target_version = other.target_version;
42+
return *this;
43+
}
44+
45+
void runOnOperation();
46+
47+
protected:
48+
::mlir::Pass::Option<bool> serialize{*this, "serialize", llvm::cl::desc("")};
49+
::mlir::Pass::Option<int> target_version{*this, "target-version",
50+
llvm::cl::desc("")};
51+
};
52+
53+
inline std::unique_ptr<::mlir::Pass> createMosaicSerdePass() {
54+
return std::make_unique<MosaicSerdePass>();
55+
}
56+
57+
inline std::unique_ptr<::mlir::Pass> createMosaicSerdePass(
58+
MosaicSerdePassOptions options) {
59+
return std::make_unique<MosaicSerdePass>(std::move(options));
60+
}
61+
62+
inline void registerMosaicSerdePass() {
63+
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
64+
return createMosaicSerdePass();
65+
});
66+
}
67+
68+
} // namespace mlir::tpu
69+
70+
#endif // THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_TRANSFORMS_SERDE_H_

0 commit comments

Comments
 (0)