@@ -296,17 +296,34 @@ absl::StatusOr<LoadedExecutableRef> PjRtLoadedExecutable::Create(
296296 TF_ASSIGN_OR_RETURN (
297297 auto result_memory_kinds,
298298 GetFirstModuleOutputMemoryKinds (pjrt_loaded_executable.get ()));
299- TF_ASSIGN_OR_RETURN (auto hlo_modules,
300- pjrt_loaded_executable->GetHloModules ());
301- if (hlo_modules.empty ()) {
302- return FailedPrecondition (" Requires at least one HloModule." );
303- }
304- TF_ASSIGN_OR_RETURN (std::vector<xla::LayoutMode> output_layout_modes,
305- GetLayoutModes (*hlo_modules.front (), " out_layout_modes" ,
306- result_element_types.size ()));
307- TF_ASSIGN_OR_RETURN (auto output_layouts,
308- GetFirstModuleOutputLayouts (pjrt_loaded_executable.get (),
309- output_layout_modes));
299+ // Obtaining output layout modes and output layouts directly from
300+ // `PjRtLoadedExecutable` may fail because the currently PjRt implementations
301+ // often fetch and serialize the optimized HLO. For now, we gracefully
302+ // handle it by omitting output layouts at creation time and using output
303+ // `PjRtBuffer`'s concrete layouts.
304+ // TODO(hyeontaek): Add a way to obtain output layout modes and
305+ // `PjRtLoadedExecutable::GetOutputLayouts()` without causing the optimized
306+ // HLO to be serialized and fetched.
307+ std::optional<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
308+ output_layouts;
309+ absl::StatusOr<std::vector<std::shared_ptr<HloModule>>> hlo_modules =
310+ pjrt_loaded_executable->GetHloModules ();
311+ if (hlo_modules.ok ()) {
312+ if (hlo_modules->empty ()) {
313+ return FailedPrecondition (" Requires at least one HloModule." );
314+ }
315+ absl::StatusOr<std::vector<xla::LayoutMode>> output_layout_modes =
316+ GetLayoutModes (*hlo_modules->front (), " out_layout_modes" ,
317+ result_element_types.size ());
318+ if (output_layout_modes.ok ()) {
319+ absl::StatusOr<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
320+ first_module_output_layouts = GetFirstModuleOutputLayouts (
321+ pjrt_loaded_executable.get (), *output_layout_modes);
322+ if (first_module_output_layouts.ok ()) {
323+ output_layouts = *std::move (first_module_output_layouts);
324+ }
325+ }
326+ }
310327 return CreateInternal (client, std::move (pjrt_loaded_executable),
311328 result_element_types, result_dimensions,
312329 /* result_hlo_sharding=*/ std::nullopt ,
@@ -352,8 +369,8 @@ absl::StatusOr<LoadedExecutableRef> PjRtLoadedExecutable::Create(
352369 // will use the MLIR as scratch space, or possibly even deallocate it.
353370 TF_ASSIGN_OR_RETURN (const std::vector<xla::Shape> result_shapes,
354371 ResultShapesOfModule (module ));
355- TF_ASSIGN_OR_RETURN ( const std::vector<xla::LayoutMode> output_layout_modes,
356- GetOutputLayoutModes (module ) );
372+ absl::StatusOr< std::vector<xla::LayoutMode>> output_layout_modes =
373+ GetOutputLayoutModes (module );
357374
358375 TF_ASSIGN_OR_RETURN (auto pjrt_loaded_executable,
359376 client->pjrt_client ()->CompileAndLoad (
@@ -372,9 +389,24 @@ absl::StatusOr<LoadedExecutableRef> PjRtLoadedExecutable::Create(
372389 TF_ASSIGN_OR_RETURN (
373390 auto result_memory_kinds,
374391 GetFirstModuleOutputMemoryKinds (pjrt_loaded_executable.get ()));
375- TF_ASSIGN_OR_RETURN (auto output_layouts,
376- GetFirstModuleOutputLayouts (
377- pjrt_loaded_executable.get (), output_layout_modes));
392+ // Obtaining output layout modes and output layouts directly from
393+ // `PjRtLoadedExecutable` may fail because the currently PjRt
394+ // implementations often fetch and serialize the optimized HLO. For now, we
395+ // gracefully handle it by omitting output layouts at creation time and
396+ // using output `PjRtBuffer`'s concrete layouts.
397+ // TODO(hyeontaek): Add a way to obtain output layout modes and
398+ // `PjRtLoadedExecutable::GetOutputLayouts()` without causing the optimized
399+ // HLO to be serialized and fetched.
400+ std::optional<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
401+ output_layouts;
402+ if (output_layout_modes.ok ()) {
403+ absl::StatusOr<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
404+ first_module_output_layouts = GetFirstModuleOutputLayouts (
405+ pjrt_loaded_executable.get (), *output_layout_modes);
406+ if (first_module_output_layouts.ok ()) {
407+ output_layouts = *std::move (first_module_output_layouts);
408+ }
409+ }
378410 return CreateInternal (client, std::move (pjrt_loaded_executable),
379411 result_element_types, result_dimensions,
380412 /* result_hlo_sharding=*/ std::nullopt ,
@@ -405,9 +437,24 @@ absl::StatusOr<LoadedExecutableRef> PjRtLoadedExecutable::Create(
405437 TF_ASSIGN_OR_RETURN (
406438 auto result_memory_kinds,
407439 GetFirstModuleOutputMemoryKinds (pjrt_loaded_executable.get ()));
408- TF_ASSIGN_OR_RETURN (auto output_layouts,
409- GetFirstModuleOutputLayouts (
410- pjrt_loaded_executable.get (), output_layout_modes));
440+ // Obtaining output layout modes and output layouts directly from
441+ // `PjRtLoadedExecutable` may fail because the currently PjRt
442+ // implementations often fetch and serialize the optimized HLO. For now, we
443+ // gracefully handle it by omitting output layouts at creation time and
444+ // using output `PjRtBuffer`'s concrete layouts.
445+ // TODO(hyeontaek): Add a way to obtain output layout modes and
446+ // `PjRtLoadedExecutable::GetOutputLayouts()` without causing the optimized
447+ // HLO to be serialized and fetched.
448+ std::optional<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
449+ output_layouts;
450+ if (output_layout_modes.ok ()) {
451+ absl::StatusOr<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
452+ first_module_output_layouts = GetFirstModuleOutputLayouts (
453+ pjrt_loaded_executable.get (), *output_layout_modes);
454+ if (first_module_output_layouts.ok ()) {
455+ output_layouts = *std::move (first_module_output_layouts);
456+ }
457+ }
411458 return CreateInternal (
412459 client, std::move (pjrt_loaded_executable),
413460 shape_partial_info.element_types , shape_partial_info.dimensions ,
@@ -423,7 +470,8 @@ absl::StatusOr<LoadedExecutableRef> PjRtLoadedExecutable::CreateInternal(
423470 absl::Span<const xla::DimensionVector> result_dimensions,
424471 const std::optional<xla::HloSharding>& result_hlo_sharding,
425472 const std::optional<std::vector<absl::string_view>>& result_memory_kinds,
426- const std::vector<std::shared_ptr<const xla::PjRtLayout>>& output_layouts,
473+ const std::optional<std::vector<std::shared_ptr<const xla::PjRtLayout>>>&
474+ output_layouts,
427475 std::vector<tsl::RCReference<LoadedHostCallback>> loaded_host_callbacks,
428476 DeviceListRef executable_devices) {
429477 // For jit(pmap(...)), the device assignment (passed as `executable_devices`)
@@ -596,7 +644,8 @@ PjRtLoadedExecutable::PjRtLoadedExecutable(
596644 host_send_recv_callbacks,
597645 std::vector<DType> output_dtypes, std::vector<Shape> output_shapes,
598646 std::vector<ShardingRef> output_shardings,
599- std::vector<std::shared_ptr<const xla::PjRtLayout>> output_layouts)
647+ std::optional<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
648+ output_layouts)
600649 : client_(client),
601650 pjrt_loaded_executable_ (std::move(pjrt_loaded_executable)),
602651 devices_(std::move(devices)),
@@ -812,6 +861,41 @@ PjRtLoadedExecutable::Execute(absl::Span<ArrayRef> args,
812861 // memory_kind shares the same Sharding object.
813862 absl::flat_hash_map<MemoryKind, ShardingRef> single_device_shardings;
814863
864+ std::vector<std::shared_ptr<const xla::PjRtLayout>> layouts;
865+ layouts.reserve (num_outputs);
866+ if (output_layouts_.has_value ()) {
867+ // TODO(hyeontaek): Once we can get `output_layouts_` reliably, only keep
868+ // this path.
869+ layouts = *output_layouts_;
870+ } else if (!pjrt_outputs.empty ()) {
871+ for (int i = 0 ; i < num_outputs; ++i) {
872+ auto layout = output_dtypes_[i].kind () == xla::ifrt::DType::kToken
873+ ? std::make_shared<xla::PjRtLayout>(xla::Layout ())
874+ : pjrt_outputs.front ()[i]->layout ();
875+ layouts.push_back (std::move (layout));
876+ }
877+ } else {
878+ auto maybe_layouts = GetOutputLayouts ();
879+ if (absl::IsUnimplemented (maybe_layouts.status ())) {
880+ for (int i = 0 ; i < num_outputs; ++i) {
881+ std::shared_ptr<const xla::PjRtLayout> layout;
882+ if (output_dtypes_[i].kind () == xla::ifrt::DType::kToken ) {
883+ layout = std::make_shared<xla::PjRtLayout>(xla::Layout ());
884+ } else {
885+ TF_ASSIGN_OR_RETURN (layout,
886+ client_->GetDefaultPjRtLayout (
887+ output_dtypes_[i], output_shapes_[i].dims (),
888+ devices_->devices ().front (),
889+ output_shardings_[i]->memory_kind ()));
890+ }
891+ layouts.push_back (std::move (layout));
892+ }
893+ } else {
894+ TF_RETURN_IF_ERROR (maybe_layouts.status ());
895+ layouts = *std::move (maybe_layouts);
896+ }
897+ }
898+
815899 for (int i = 0 ; i < num_outputs; ++i) {
816900 PjRtArray::PjRtBuffers buffers;
817901 buffers.reserve (num_computations);
@@ -852,7 +936,7 @@ PjRtLoadedExecutable::Execute(absl::Span<ArrayRef> args,
852936 }
853937 outputs.push_back (*PjRtArray::Create (
854938 client_, output_dtypes_[i], output_shapes_[i], *std::move (sharding),
855- std::move (buffers), output_layouts_ [i]));
939+ std::move (buffers), std::move (layouts [i]) ));
856940 }
857941
858942 ExecuteResult result;
0 commit comments