Skip to content

Commit ca8d1ec

Browse files
hlo_runner_pjrt should keep hlo module config's seed() in optimized hlo module
PiperOrigin-RevId: 837185528
1 parent 1351df8 commit ca8d1ec

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

xla/service/hlo_runner_pjrt.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,6 +491,12 @@ HloRunnerPjRt::CreateExecutable(std::unique_ptr<HloModule> module,
491491
absl::StatusOr<std::unique_ptr<PjRtExecutable>> pjrt_executable =
492492
pjrt_client_->Compile(computation, compile_options);
493493
if (pjrt_executable.ok()) {
494+
absl::StatusOr<std::vector<std::shared_ptr<HloModule>>> hlo_modules =
495+
pjrt_executable->get()->GetHloModules();
496+
if (hlo_modules.ok() && !hlo_modules->empty()) {
497+
std::shared_ptr<HloModule> exe_module = (*hlo_modules)[0];
498+
exe_module->mutable_config().set_seed(module->config().seed());
499+
}
494500
return std::make_unique<HloRunnerPjRtExecutable>(
495501
this, *std::move(pjrt_executable));
496502
}
@@ -502,6 +508,14 @@ HloRunnerPjRt::CreateExecutable(std::unique_ptr<HloModule> module,
502508
TF_ASSIGN_OR_RETURN(
503509
std::unique_ptr<PjRtLoadedExecutable> pjrt_loaded_executable,
504510
pjrt_client_->CompileAndLoad(computation, std::move(compile_options)));
511+
if (pjrt_loaded_executable != nullptr) {
512+
absl::StatusOr<std::vector<std::shared_ptr<HloModule>>> hlo_modules =
513+
pjrt_loaded_executable->GetHloModules();
514+
if (hlo_modules.ok() && !hlo_modules->empty()) {
515+
std::shared_ptr<HloModule> exe_module = (*hlo_modules)[0];
516+
exe_module->mutable_config().set_seed(module->config().seed());
517+
}
518+
}
505519
return std::make_unique<HloRunnerPjRtExecutable>(
506520
this, std::move(pjrt_loaded_executable));
507521
}

0 commit comments

Comments
 (0)