Commit be5b1af

jaro-sevcik authored and Google-ML-Automation committed
PR #20426: Layout assignment: Reset memory space in result layout
Imported from GitHub PR #20426

Layout assignment should not set any memory space on any of the
instructions, even if the entry computation layout has a non-default
memory space. In one place the memory space was leaking (causing
weight-offloading crashes on real models); this patch addresses that.

Drive-by: introduce a helper function for the copy-pasted
implementations of resetting the memory space in a layout.

Copybara import of the project:

--
29bfdd8 by Jaroslav Sevcik <[email protected]>:

Reset memory space in result layout

Merging this change closes #20426

COPYBARA_INTEGRATE_REVIEW=#20426 from jaro-sevcik:scrub-memory-space-in-layout-assignment 29bfdd8
PiperOrigin-RevId: 707185192
1 parent 432da09 commit be5b1af
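
To make the fixed invariant concrete, here is a minimal sketch (illustration only, not part of the patch; it uses the public XLA Shape/Layout API, and the function name MakeHostParameterLayout is hypothetical). An entry computation layout may pin a parameter to host memory, yet after this patch every layout that layout assignment writes onto instructions carries only the default memory space:

#include "xla/layout.h"
#include "xla/shape_layout.h"
#include "xla/shape_util.h"

// Build an entry-parameter layout that lives in host memory
// (Layout::kHostMemorySpace, printed as S(5) in HLO text).
xla::ShapeLayout MakeHostParameterLayout() {
  return xla::ShapeLayout(xla::ShapeUtil::MakeShapeWithDenseLayout(
      xla::F32, /*dimensions=*/{4, 4}, /*minor_to_major=*/{1, 0},
      /*tiles=*/{}, /*tail_padding_alignment_in_elements=*/1,
      /*element_size_in_bits=*/0, xla::Layout::kHostMemorySpace));
}

// Invariant after this patch: running layout assignment against such an
// entry layout must leave instruction layouts in Layout::kDefaultMemorySpace;
// the host offloader re-derives host placements in a later pass.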

File tree

2 files changed: +82 -29 lines changed

  xla/service/layout_assignment.cc
  xla/service/layout_assignment_test.cc

xla/service/layout_assignment.cc

Lines changed: 29 additions & 29 deletions
@@ -657,6 +657,20 @@ absl::Status PropagateParameterLayoutToUsers(const HloInstruction* instruction,
   return absl::OkStatus();
 }
 
+absl::Status ResetMemorySpaceInLayout(ShapeLayout& mutable_shape_layout) {
+  Shape shape = mutable_shape_layout.shape();
+  TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus(
+      &shape, [](Shape* subshape, const ShapeIndex& shape_index) {
+        if (subshape->has_layout() && subshape->IsArray()) {
+          subshape->mutable_layout()->set_memory_space(
+              Layout::kDefaultMemorySpace);
+        }
+        return absl::OkStatus();
+      }));
+  TF_RETURN_IF_ERROR(mutable_shape_layout.CopyLayoutFromShape(shape));
+  return absl::OkStatus();
+}
+
 }  // namespace
 
 absl::Status LayoutAssignment::AddMandatoryConstraints(
@@ -693,27 +707,18 @@ absl::Status LayoutAssignment::AddMandatoryConstraints(
          entry_computation_layout_->AnyLayoutSet()) ||
         (conditional_mismatch_.count(constraints->computation()) == 0 &&
          constraints->computation_constraint().parameter_layout_is_set())) {
-      const ShapeLayout& parameter_layout =
+      ShapeLayout parameter_layout =
           constraints->computation_layout().parameter_layout(
               instruction->parameter_number());
       // Allow some parameter/result layouts to be unset in the entry
       // computation.
       if (parameter_layout.AnyLayoutIsSet()) {
+        // Clear out memory space in layout. Host offloader will do the
+        // analysis later.
+        TF_RETURN_IF_ERROR(ResetMemorySpaceInLayout(parameter_layout));
         // Parameter layouts must match the respective layout in
         // ComputationLayout, if there is one.
         Shape param_shape = parameter_layout.shape();
-        // Clear out memory space in layout. Host offloader will do the
-        // analysis later.
-        TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus(
-            &param_shape, [](Shape* subshape, const ShapeIndex& index) {
-              if (!subshape->has_layout() || !subshape->IsArray()) {
-                return absl::OkStatus();
-              }
-              subshape->mutable_layout()->set_memory_space(
-                  Layout::kDefaultMemorySpace);
-              return absl::OkStatus();
-            }));
-
         TF_RETURN_IF_ERROR(SetInstructionLayout(param_shape, instruction));
         if (reverse_computation_order_) {
           TF_RETURN_IF_ERROR(PropagateParameterLayoutToUsers(
@@ -2033,16 +2038,7 @@ absl::Status LayoutAssignment::PropagateResultConstraint(
   // Clear out memory space in layout for entry computation root. Host offloader
   // will do the analysis later and add back the memory space for host outputs.
   if (constraints->computation()->IsEntryComputation()) {
-    Shape result_shape = result_layout.shape();
-    TF_RETURN_IF_ERROR(ShapeUtil::ForEachMutableSubshapeWithStatus(
-        &result_shape, [](Shape* subshape, const ShapeIndex& shape_index) {
-          if (subshape->has_layout() && subshape->IsArray()) {
-            subshape->mutable_layout()->set_memory_space(
-                Layout::kDefaultMemorySpace);
-          }
-          return absl::OkStatus();
-        }));
-    TF_RETURN_IF_ERROR(result_layout.CopyLayoutFromShape(result_shape));
+    TF_RETURN_IF_ERROR(ResetMemorySpaceInLayout(result_layout));
   }
 
   // Propagate the use constraint of the root instruction up to the logical
@@ -2232,25 +2228,29 @@ absl::Status LayoutAssignment::AssignLayouts(LayoutConstraints& constraints) {
   // layout constraint.
   if (constraints.ResultLayout() != nullptr &&
       constraints.ResultLayout()->LayoutIsSet()) {
+    ShapeLayout result_layout = *constraints.ResultLayout();
+    // Clear out memory space in layout. Host offloader will do the
+    // analysis later.
+    TF_RETURN_IF_ERROR(ResetMemorySpaceInLayout(result_layout));
     // Layout assignment at this point only does minor-to-major assignment so
     // tiling info should be ignored here for comparison.
     VLOG(5) << "Computation result layout needs root copying\n";
-    if (!constraints.ResultLayout()->MatchesLayoutInShape(
+    if (!result_layout.MatchesLayoutInShape(
             computation->root_instruction()->shape(),
             /*minor_to_major_only=*/true)) {
       TF_ASSIGN_OR_RETURN(
           HloInstruction * new_root,
-          CreateCopyWithNewLayout(constraints.ResultLayout()->shape(),
+          CreateCopyWithNewLayout(result_layout.shape(),
                                   computation->root_instruction()));
       computation->set_root_instruction(new_root);
     } else {
       // Copy the tiling info/tail_padding_alignment_in_elements specified in
       // result layout.
-      auto copy_tiling = [&constraints](xla::Shape* subshape,
-                                        const xla::ShapeIndex& index) {
+      auto copy_tiling = [&result_layout](xla::Shape* subshape,
+                                          const xla::ShapeIndex& index) {
         if (subshape->IsArray()) {
-          const Shape& result_shape = ShapeUtil::GetSubshape(
-              constraints.ResultLayout()->shape(), index);
+          const Shape& result_shape =
+              ShapeUtil::GetSubshape(result_layout.shape(), index);
           if (result_shape.layout().tiles_size() != 0) {
             subshape->mutable_layout()->mutable_tiles()->assign(
                 result_shape.layout().tiles().begin(),
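
For reference, a minimal sketch of how the new helper behaves (hypothetical caller code, not from the patch; ResetMemorySpaceInLayout sits in this file's anonymous namespace, so the snippet only illustrates its effect):

absl::Status ResetExample() {
  // A layout whose only array subshape is pinned to host memory.
  xla::ShapeLayout layout(xla::ShapeUtil::MakeShapeWithDenseLayout(
      xla::F32, {4, 4}, {1, 0}, /*tiles=*/{},
      /*tail_padding_alignment_in_elements=*/1,
      /*element_size_in_bits=*/0, xla::Layout::kHostMemorySpace));
  // Scrub the memory space: every array subshape drops back to
  // Layout::kDefaultMemorySpace, while minor-to-major order, tiles, and the
  // other layout fields survive the CopyLayoutFromShape round trip.
  TF_RETURN_IF_ERROR(ResetMemorySpaceInLayout(layout));
  return absl::OkStatus();
}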

xla/service/layout_assignment_test.cc

Lines changed: 53 additions & 0 deletions
@@ -1367,6 +1367,59 @@ ENTRY %CustomCallLayoutConstrainedTupleResult (p0: f32[4,4]) -> (f32[4,4]{1,0},
   ExpectTupleLayoutIs(custom_call->shape(), {{1, 0}, {0, 1}});
 }
 
+TEST_F(LayoutAssignmentTest, MemorySpaceRemoved) {
+  const char* module_str = R"(
+HloModule MixedHostDeviceResult
+
+ENTRY %MixedHostDeviceResult {
+  %p0 = f32[4,4] parameter(0)
+  %d = f32[4,4]{1,0} custom-call(%p0), custom_call_target="MoveToDevice", metadata={preserve_layout=true}
+  ROOT %tuple = (f32[4,4], f32[4,4]) tuple(%p0, %d)
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<VerifiedHloModule> m,
+      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
+  ComputationLayout computation_layout = m->entry_computation_layout();
+
+  // Set the parameter to be in host memory.
+  *computation_layout.mutable_parameter_layout(0) =
+      ShapeLayout(ShapeUtil::MakeShapeWithDenseLayout(
+          F32, {4, 4}, {1, 0}, /*tiles=*/{},
+          /*tail_padding_alignment_in_elements=*/1, /*element_size_in_bits=*/0,
+          Layout::kHostMemorySpace));
+  // Set one result component to be in host memory, the other one on device.
+  // Also make sure to request an incompatible result layout so that the layout
+  // assignment pass has to copy the layout from the entry computation layout.
+  *computation_layout.mutable_result_layout() =
+      ShapeLayout(ShapeUtil::MakeTupleShape(
+          {ShapeUtil::MakeShapeWithDenseLayout(
+               F32, {4, 4}, {1, 0}, /*tiles=*/{},
+               /*tail_padding_alignment_in_elements=*/1,
+               /*element_size_in_bits=*/0, Layout::kHostMemorySpace),
+           ShapeUtil::MakeShapeWithDenseLayout(
+               F32, {4, 4}, {0, 1}, /*tiles=*/{},
+               /*tail_padding_alignment_in_elements=*/1,
+               /*element_size_in_bits=*/0, Layout::kDefaultMemorySpace)}));
+  AssignLayouts(m.get(), &computation_layout);
+
+  // Verify that the memory space did not leak from the entry computation layout
+  // to the parameter or to the result.
+  Shape result_shape = m->entry_computation()->root_instruction()->shape();
+  EXPECT_EQ(
+      ShapeUtil::GetTupleElementShape(result_shape, 0).layout().memory_space(),
+      Layout::kDefaultMemorySpace);
+  EXPECT_EQ(
+      ShapeUtil::GetTupleElementShape(result_shape, 1).layout().memory_space(),
+      Layout::kDefaultMemorySpace);
+
+  const HloInstruction* parameter = FindInstruction(m.get(), "p0");
+  EXPECT_EQ(parameter->shape().layout().memory_space(),
+            Layout::kDefaultMemorySpace);
+
+  ExpectTupleLayoutIs(result_shape, {{1, 0}, {0, 1}});
+}
+
 absl::Status AssignLayoutsToComputation(
     HloModule* m, ChannelLayoutConstraints* channel_constraints = nullptr) {
   if (!m->entry_computation_layout().result_layout().LayoutIsSet()) {
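
To run just the new test, something like the following should work (the Bazel target name is an assumption inferred from the file path; check the BUILD file):

bazel test //xla/service:layout_assignment_test \
  --test_filter=LayoutAssignmentTest.MemorySpaceRemoved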
