Skip to content

Commit 849f758

Browse files
GleasonKGoogle-ML-Automation
authored andcommitted
[StableHLO] Pin StableHLOv0.19.0 for older PJRT plugins.
This is a temporary measure to allow plugins to update to latest jaxlib. PiperOrigin-RevId: 681464871
1 parent 8cb57d6 commit 849f758

File tree

4 files changed

+20
-13
lines changed

4 files changed

+20
-13
lines changed

xla/pjrt/mlir_to_hlo.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,15 +244,21 @@ absl::Status UpgradeVersionedStablehlo(mlir::ModuleOp mlir_module) {
244244
return absl::OkStatus();
245245
}
246246

247-
std::string GetDefaultStablehloVersion() {
247+
std::string GetDefaultStablehloVersion(std::optional<int64_t> plugin_version) {
248+
// TODO: (b/370803410) Use WEEK_12 in PJRT, some plugins were not up to date,
249+
// so temporarily using 1.0.0 to allow them time for a new release.
250+
// PJRT v54 released Jun 10, so most plugins should use WEEK_12 by default.
251+
if (plugin_version.has_value() && plugin_version.value() < 54) {
252+
return "0.19.0";
253+
}
254+
248255
// This version must be >=12w old.
249256
return mlir::vhlo::Version::fromCompatibilityRequirement(
250257
mlir::vhlo::Version::CompatibilityRequirement::WEEK_12)
251258
.toString();
252259
}
253260

254261
absl::StatusOr<std::string> Serialize(mlir::ModuleOp module,
255-
std::optional<int64_t> /*plugin_version*/,
256262
absl::string_view target, bool inplace) {
257263
// Current PJRT users expect 12 weeks forward compat, VHLO provides this
258264
// compat.

xla/pjrt/mlir_to_hlo.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ absl::Status ParseMlirModuleStringAndConvertToXlaComputation(
4242

4343
// Returns a version of StableHLO ~12w old, for forward compatibility with PJRT
4444
// plugins on a quarterly update cycle.
45-
std::string GetDefaultStablehloVersion();
45+
std::string GetDefaultStablehloVersion(
46+
std::optional<int64_t> plugin_version = std::nullopt);
4647

4748
// Serialize using MLIR Bytecode Format which does not guarantee forward or
4849
// backward compatiblity of the dialects used. If passing StableHLO with forward
@@ -52,7 +53,6 @@ std::string GetDefaultStablehloVersion();
5253
// For plugin_version < 41, returns `SerializeUsingNativeBytecode`.
5354
// For plugin_version >= 41, returns `SerializeUsingVersionedStablehlo`.
5455
absl::StatusOr<std::string> Serialize(mlir::ModuleOp mlir_module,
55-
std::optional<int64_t> plugin_version,
5656
absl::string_view target,
5757
bool inplace = false);
5858

xla/pjrt/mlir_to_hlo_test.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ TEST(MlirToHloTest, StablehloTest) {
5151
mlir::MLIRContext context;
5252
TF_ASSERT_OK_AND_ASSIGN(mlir::OwningOpRef<mlir::ModuleOp> module,
5353
ParseMlirModuleString(kProgram, context));
54-
TF_ASSERT_OK_AND_ASSIGN(std::string blob, Serialize(*module, 47, "1.0.0"));
54+
TF_ASSERT_OK_AND_ASSIGN(std::string blob, Serialize(*module, "1.0.0"));
5555

5656
// StableHLO uses VHLO for PJRT serialization.
5757
EXPECT_THAT(blob, IsVhloArtifact("1.0.0"));
@@ -69,7 +69,7 @@ TEST(MlirToHloTest, ChloTest) {
6969
mlir::MLIRContext context;
7070
TF_ASSERT_OK_AND_ASSIGN(mlir::OwningOpRef<mlir::ModuleOp> module,
7171
ParseMlirModuleString(kProgram, context));
72-
TF_ASSERT_OK_AND_ASSIGN(std::string blob, Serialize(*module, 47, "1.0.0"));
72+
TF_ASSERT_OK_AND_ASSIGN(std::string blob, Serialize(*module, "1.0.0"));
7373

7474
// CHLO decomposes to StableHLO, so uses VHLO serialization.
7575
EXPECT_THAT(blob, IsVhloArtifact("1.0.0"));
@@ -86,7 +86,7 @@ TEST(MlirToHloTest, ChloTanOpTest) {
8686
mlir::MLIRContext context;
8787
TF_ASSERT_OK_AND_ASSIGN(mlir::OwningOpRef<mlir::ModuleOp> module,
8888
ParseMlirModuleString(kProgram, context));
89-
TF_ASSERT_OK_AND_ASSIGN(std::string blob, Serialize(*module, 47, "1.0.0"));
89+
TF_ASSERT_OK_AND_ASSIGN(std::string blob, Serialize(*module, "1.0.0"));
9090

9191
// CHLO decomposes to StableHLO, so uses VHLO serialization.
9292
EXPECT_THAT(blob, IsVhloArtifact("1.0.0"));
@@ -104,7 +104,7 @@ TEST(MlirToHloTest, MhloTest) {
104104
mlir::MLIRContext context;
105105
TF_ASSERT_OK_AND_ASSIGN(mlir::OwningOpRef<mlir::ModuleOp> module,
106106
ParseMlirModuleString(kProgram, context));
107-
TF_ASSERT_OK_AND_ASSIGN(std::string blob, Serialize(*module, 47, "1.0.0"));
107+
TF_ASSERT_OK_AND_ASSIGN(std::string blob, Serialize(*module, "1.0.0"));
108108

109109
// MHLO and other dialects use native MLIR bytecode, not VHLO.
110110
EXPECT_THAT(blob, Not(IsVhloArtifact("1.0.0")));

xla/pjrt/pjrt_c_api_client.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -395,8 +395,9 @@ absl::StatusOr<std::unique_ptr<PjRtLoadedExecutable>> PjRtCApiClient::Compile(
395395
if (!pjrt_c_api()) llvm::report_fatal_error("pjrt_c_api is null");
396396
TF_ASSIGN_OR_RETURN(
397397
std::string serialized,
398-
xla::Serialize(module, plugin_attributes()->pjrt_c_api_minor_version,
399-
xla::GetDefaultStablehloVersion()));
398+
xla::Serialize(module,
399+
xla::GetDefaultStablehloVersion(
400+
plugin_attributes()->pjrt_c_api_minor_version)));
400401
std::string format(pjrt::kMlirFormat);
401402
return InitializeArgsAndCompile(this, c_api_, c_client_.get(), options,
402403
serialized, format);
@@ -2311,9 +2312,9 @@ absl::StatusOr<std::unique_ptr<PjRtExecutable>> PjRtCApiCompiler::Compile(
23112312
if (client) {
23122313
plugin_version = client->plugin_attributes()->pjrt_c_api_minor_version;
23132314
}
2314-
TF_ASSIGN_OR_RETURN(std::string serialized,
2315-
xla::Serialize(module, plugin_version,
2316-
xla::GetDefaultStablehloVersion()));
2315+
TF_ASSIGN_OR_RETURN(
2316+
std::string serialized,
2317+
xla::Serialize(module, xla::GetDefaultStablehloVersion(plugin_version)));
23172318
std::string format(pjrt::kMlirFormat);
23182319
return InitializeArgsAndCompileAot(c_api_, client, options, topology,
23192320
serialized, format);

0 commit comments

Comments
 (0)