@@ -491,6 +491,12 @@ HloRunnerPjRt::CreateExecutable(std::unique_ptr<HloModule> module,
491491 absl::StatusOr<std::unique_ptr<PjRtExecutable>> pjrt_executable =
492492 pjrt_client_->Compile (computation, compile_options);
493493 if (pjrt_executable.ok ()) {
494+ absl::StatusOr<std::vector<std::shared_ptr<HloModule>>> hlo_modules =
495+ pjrt_executable->get ()->GetHloModules ();
496+ if (hlo_modules.ok () && !hlo_modules->empty ()) {
497+ std::shared_ptr<HloModule> exe_module = (*hlo_modules)[0 ];
498+ exe_module->mutable_config ().set_seed (module ->config ().seed ());
499+ }
494500 return std::make_unique<HloRunnerPjRtExecutable>(
495501 this , *std::move (pjrt_executable));
496502 }
@@ -502,6 +508,14 @@ HloRunnerPjRt::CreateExecutable(std::unique_ptr<HloModule> module,
502508 TF_ASSIGN_OR_RETURN (
503509 std::unique_ptr<PjRtLoadedExecutable> pjrt_loaded_executable,
504510 pjrt_client_->CompileAndLoad (computation, std::move (compile_options)));
511+ if (pjrt_loaded_executable != nullptr ) {
512+ absl::StatusOr<std::vector<std::shared_ptr<HloModule>>> hlo_modules =
513+ pjrt_loaded_executable->GetHloModules ();
514+ if (hlo_modules.ok () && !hlo_modules->empty ()) {
515+ std::shared_ptr<HloModule> exe_module = (*hlo_modules)[0 ];
516+ exe_module->mutable_config ().set_seed (module ->config ().seed ());
517+ }
518+ }
505519 return std::make_unique<HloRunnerPjRtExecutable>(
506520 this , std::move (pjrt_loaded_executable));
507521}
0 commit comments