Skip to content

Commit 77643f3

Browse files
hyeontaek and Google-ML-Automation
authored and committed
[PjRt-IFRT] Temporary workaround for output layout handling
PjRt-IFRT directly or indirectly fetched optimized HLO to get the output layout mode and output layouts. This seems to introduce a regression in some jobs that use PJRT C API and have a too large serialized HLO (> 2 GiB). As a workaround, PjRt-IFRT gracefully handles output layout mode and layout discovery errors, and falls back to concrete layouts that are directly obtained from output `PjRtBuffer`s, should give the same behavior before/after the default layout handling change. Further changes will follow to discover default layout modes and layouts without going through `PjRtLoadedExecutable::GetHloModules()`. PiperOrigin-RevId: 820785277
1 parent 43eb396 commit 77643f3

File tree

2 files changed

+112
-25
lines changed

2 files changed

+112
-25
lines changed

xla/python/pjrt_ifrt/pjrt_executable.cc

Lines changed: 106 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -296,17 +296,34 @@ absl::StatusOr<LoadedExecutableRef> PjRtLoadedExecutable::Create(
296296
TF_ASSIGN_OR_RETURN(
297297
auto result_memory_kinds,
298298
GetFirstModuleOutputMemoryKinds(pjrt_loaded_executable.get()));
299-
TF_ASSIGN_OR_RETURN(auto hlo_modules,
300-
pjrt_loaded_executable->GetHloModules());
301-
if (hlo_modules.empty()) {
302-
return FailedPrecondition("Requires at least one HloModule.");
303-
}
304-
TF_ASSIGN_OR_RETURN(std::vector<xla::LayoutMode> output_layout_modes,
305-
GetLayoutModes(*hlo_modules.front(), "out_layout_modes",
306-
result_element_types.size()));
307-
TF_ASSIGN_OR_RETURN(auto output_layouts,
308-
GetFirstModuleOutputLayouts(pjrt_loaded_executable.get(),
309-
output_layout_modes));
299+
// Obtaining output layout modes and output layouts directly from
300+
// `PjRtLoadedExecutable` may fail because current PjRt implementations
301+
// often fetch and serialize the optimized HLO. For now, we gracefully
302+
// handle it by omitting output layouts at creation time and using output
303+
// `PjRtBuffer`'s concrete layouts.
304+
// TODO(hyeontaek): Add a way to obtain output layout modes and
305+
// `PjRtLoadedExecutable::GetOutputLayouts()` without causing the optimized
306+
// HLO to be serialized and fetched.
307+
std::optional<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
308+
output_layouts;
309+
absl::StatusOr<std::vector<std::shared_ptr<HloModule>>> hlo_modules =
310+
pjrt_loaded_executable->GetHloModules();
311+
if (hlo_modules.ok()) {
312+
if (hlo_modules->empty()) {
313+
return FailedPrecondition("Requires at least one HloModule.");
314+
}
315+
absl::StatusOr<std::vector<xla::LayoutMode>> output_layout_modes =
316+
GetLayoutModes(*hlo_modules->front(), "out_layout_modes",
317+
result_element_types.size());
318+
if (output_layout_modes.ok()) {
319+
absl::StatusOr<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
320+
first_module_output_layouts = GetFirstModuleOutputLayouts(
321+
pjrt_loaded_executable.get(), *output_layout_modes);
322+
if (first_module_output_layouts.ok()) {
323+
output_layouts = *std::move(first_module_output_layouts);
324+
}
325+
}
326+
}
310327
return CreateInternal(client, std::move(pjrt_loaded_executable),
311328
result_element_types, result_dimensions,
312329
/*result_hlo_sharding=*/std::nullopt,
@@ -352,8 +369,8 @@ absl::StatusOr<LoadedExecutableRef> PjRtLoadedExecutable::Create(
352369
// will use the MLIR as scratch space, or possibly even deallocate it.
353370
TF_ASSIGN_OR_RETURN(const std::vector<xla::Shape> result_shapes,
354371
ResultShapesOfModule(module));
355-
TF_ASSIGN_OR_RETURN(const std::vector<xla::LayoutMode> output_layout_modes,
356-
GetOutputLayoutModes(module));
372+
absl::StatusOr<std::vector<xla::LayoutMode>> output_layout_modes =
373+
GetOutputLayoutModes(module);
357374

358375
TF_ASSIGN_OR_RETURN(auto pjrt_loaded_executable,
359376
client->pjrt_client()->CompileAndLoad(
@@ -372,9 +389,24 @@ absl::StatusOr<LoadedExecutableRef> PjRtLoadedExecutable::Create(
372389
TF_ASSIGN_OR_RETURN(
373390
auto result_memory_kinds,
374391
GetFirstModuleOutputMemoryKinds(pjrt_loaded_executable.get()));
375-
TF_ASSIGN_OR_RETURN(auto output_layouts,
376-
GetFirstModuleOutputLayouts(
377-
pjrt_loaded_executable.get(), output_layout_modes));
392+
// Obtaining output layout modes and output layouts directly from
393+
// `PjRtLoadedExecutable` may fail because current PjRt
393+
// implementations often fetch and serialize the optimized HLO. For now, we
395+
// gracefully handle it by omitting output layouts at creation time and
396+
// using output `PjRtBuffer`'s concrete layouts.
397+
// TODO(hyeontaek): Add a way to obtain output layout modes and
398+
// `PjRtLoadedExecutable::GetOutputLayouts()` without causing the optimized
399+
// HLO to be serialized and fetched.
400+
std::optional<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
401+
output_layouts;
402+
if (output_layout_modes.ok()) {
403+
absl::StatusOr<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
404+
first_module_output_layouts = GetFirstModuleOutputLayouts(
405+
pjrt_loaded_executable.get(), *output_layout_modes);
406+
if (first_module_output_layouts.ok()) {
407+
output_layouts = *std::move(first_module_output_layouts);
408+
}
409+
}
378410
return CreateInternal(client, std::move(pjrt_loaded_executable),
379411
result_element_types, result_dimensions,
380412
/*result_hlo_sharding=*/std::nullopt,
@@ -405,9 +437,24 @@ absl::StatusOr<LoadedExecutableRef> PjRtLoadedExecutable::Create(
405437
TF_ASSIGN_OR_RETURN(
406438
auto result_memory_kinds,
407439
GetFirstModuleOutputMemoryKinds(pjrt_loaded_executable.get()));
408-
TF_ASSIGN_OR_RETURN(auto output_layouts,
409-
GetFirstModuleOutputLayouts(
410-
pjrt_loaded_executable.get(), output_layout_modes));
440+
// Obtaining output layout modes and output layouts directly from
441+
// `PjRtLoadedExecutable` may fail because current PjRt
441+
// implementations often fetch and serialize the optimized HLO. For now, we
443+
// gracefully handle it by omitting output layouts at creation time and
444+
// using output `PjRtBuffer`'s concrete layouts.
445+
// TODO(hyeontaek): Add a way to obtain output layout modes and
446+
// `PjRtLoadedExecutable::GetOutputLayouts()` without causing the optimized
447+
// HLO to be serialized and fetched.
448+
std::optional<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
449+
output_layouts;
450+
if (output_layout_modes.ok()) {
451+
absl::StatusOr<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
452+
first_module_output_layouts = GetFirstModuleOutputLayouts(
453+
pjrt_loaded_executable.get(), *output_layout_modes);
454+
if (first_module_output_layouts.ok()) {
455+
output_layouts = *std::move(first_module_output_layouts);
456+
}
457+
}
411458
return CreateInternal(
412459
client, std::move(pjrt_loaded_executable),
413460
shape_partial_info.element_types, shape_partial_info.dimensions,
@@ -423,7 +470,8 @@ absl::StatusOr<LoadedExecutableRef> PjRtLoadedExecutable::CreateInternal(
423470
absl::Span<const xla::DimensionVector> result_dimensions,
424471
const std::optional<xla::HloSharding>& result_hlo_sharding,
425472
const std::optional<std::vector<absl::string_view>>& result_memory_kinds,
426-
const std::vector<std::shared_ptr<const xla::PjRtLayout>>& output_layouts,
473+
const std::optional<std::vector<std::shared_ptr<const xla::PjRtLayout>>>&
474+
output_layouts,
427475
std::vector<tsl::RCReference<LoadedHostCallback>> loaded_host_callbacks,
428476
DeviceListRef executable_devices) {
429477
// For jit(pmap(...)), the device assignment (passed as `executable_devices`)
@@ -596,7 +644,8 @@ PjRtLoadedExecutable::PjRtLoadedExecutable(
596644
host_send_recv_callbacks,
597645
std::vector<DType> output_dtypes, std::vector<Shape> output_shapes,
598646
std::vector<ShardingRef> output_shardings,
599-
std::vector<std::shared_ptr<const xla::PjRtLayout>> output_layouts)
647+
std::optional<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
648+
output_layouts)
600649
: client_(client),
601650
pjrt_loaded_executable_(std::move(pjrt_loaded_executable)),
602651
devices_(std::move(devices)),
@@ -812,6 +861,41 @@ PjRtLoadedExecutable::Execute(absl::Span<ArrayRef> args,
812861
// memory_kind shares the same Sharding object.
813862
absl::flat_hash_map<MemoryKind, ShardingRef> single_device_shardings;
814863

864+
std::vector<std::shared_ptr<const xla::PjRtLayout>> layouts;
865+
layouts.reserve(num_outputs);
866+
if (output_layouts_.has_value()) {
867+
// TODO(hyeontaek): Once we can get `output_layouts_` reliably, only keep
868+
// this path.
869+
layouts = *output_layouts_;
870+
} else if (!pjrt_outputs.empty()) {
871+
for (int i = 0; i < num_outputs; ++i) {
872+
auto layout = output_dtypes_[i].kind() == xla::ifrt::DType::kToken
873+
? std::make_shared<xla::PjRtLayout>(xla::Layout())
874+
: pjrt_outputs.front()[i]->layout();
875+
layouts.push_back(std::move(layout));
876+
}
877+
} else {
878+
auto maybe_layouts = GetOutputLayouts();
879+
if (absl::IsUnimplemented(maybe_layouts.status())) {
880+
for (int i = 0; i < num_outputs; ++i) {
881+
std::shared_ptr<const xla::PjRtLayout> layout;
882+
if (output_dtypes_[i].kind() == xla::ifrt::DType::kToken) {
883+
layout = std::make_shared<xla::PjRtLayout>(xla::Layout());
884+
} else {
885+
TF_ASSIGN_OR_RETURN(layout,
886+
client_->GetDefaultPjRtLayout(
887+
output_dtypes_[i], output_shapes_[i].dims(),
888+
devices_->devices().front(),
889+
output_shardings_[i]->memory_kind()));
890+
}
891+
layouts.push_back(std::move(layout));
892+
}
893+
} else {
894+
TF_RETURN_IF_ERROR(maybe_layouts.status());
895+
layouts = *std::move(maybe_layouts);
896+
}
897+
}
898+
815899
for (int i = 0; i < num_outputs; ++i) {
816900
PjRtArray::PjRtBuffers buffers;
817901
buffers.reserve(num_computations);
@@ -852,7 +936,7 @@ PjRtLoadedExecutable::Execute(absl::Span<ArrayRef> args,
852936
}
853937
outputs.push_back(*PjRtArray::Create(
854938
client_, output_dtypes_[i], output_shapes_[i], *std::move(sharding),
855-
std::move(buffers), output_layouts_[i]));
939+
std::move(buffers), std::move(layouts[i])));
856940
}
857941

858942
ExecuteResult result;

xla/python/pjrt_ifrt/pjrt_executable.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,8 @@ class PjRtLoadedExecutable final
339339
absl::Span<const xla::DimensionVector> result_dimensions,
340340
const std::optional<xla::HloSharding>& result_hlo_sharding,
341341
const std::optional<std::vector<absl::string_view>>& result_memory_kinds,
342-
const std::vector<std::shared_ptr<const xla::PjRtLayout>>& output_layouts,
342+
const std::optional<std::vector<std::shared_ptr<const xla::PjRtLayout>>>&
343+
output_layouts,
343344
std::vector<tsl::RCReference<LoadedHostCallback>> loaded_host_callbacks,
344345
DeviceListRef executable_devices);
345346

@@ -353,7 +354,8 @@ class PjRtLoadedExecutable final
353354
host_send_recv_callbacks,
354355
std::vector<DType> output_dtypes, std::vector<Shape> output_shapes,
355356
std::vector<ShardingRef> output_shardings,
356-
std::vector<std::shared_ptr<const xla::PjRtLayout>> output_layouts);
357+
std::optional<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
358+
output_layouts);
357359

358360
PjRtClient* client_;
359361
std::shared_ptr<xla::PjRtLoadedExecutable> pjrt_loaded_executable_;
@@ -372,7 +374,8 @@ class PjRtLoadedExecutable final
372374
std::vector<DType> output_dtypes_;
373375
std::vector<Shape> output_shapes_;
374376
std::vector<ShardingRef> output_shardings_;
375-
std::vector<std::shared_ptr<const xla::PjRtLayout>> output_layouts_;
377+
std::optional<std::vector<std::shared_ptr<const xla::PjRtLayout>>>
378+
output_layouts_;
376379
const xla::ifrt::UserContextRef user_context_;
377380
};
378381

0 commit comments

Comments
 (0)