@@ -657,6 +657,20 @@ absl::Status PropagateParameterLayoutToUsers(const HloInstruction* instruction,
657657 return absl::OkStatus ();
658658}
659659
660+ absl::Status ResetMemorySpaceInLayout (ShapeLayout& mutable_shape_layout) {
661+ Shape shape = mutable_shape_layout.shape ();
662+ TF_RETURN_IF_ERROR (ShapeUtil::ForEachMutableSubshapeWithStatus (
663+ &shape, [](Shape* subshape, const ShapeIndex& shape_index) {
664+ if (subshape->has_layout () && subshape->IsArray ()) {
665+ subshape->mutable_layout ()->set_memory_space (
666+ Layout::kDefaultMemorySpace );
667+ }
668+ return absl::OkStatus ();
669+ }));
670+ TF_RETURN_IF_ERROR (mutable_shape_layout.CopyLayoutFromShape (shape));
671+ return absl::OkStatus ();
672+ }
673+
660674} // namespace
661675
662676absl::Status LayoutAssignment::AddMandatoryConstraints (
@@ -693,27 +707,18 @@ absl::Status LayoutAssignment::AddMandatoryConstraints(
693707 entry_computation_layout_->AnyLayoutSet ()) ||
694708 (conditional_mismatch_.count (constraints->computation ()) == 0 &&
695709 constraints->computation_constraint ().parameter_layout_is_set ())) {
696- const ShapeLayout& parameter_layout =
710+ ShapeLayout parameter_layout =
697711 constraints->computation_layout ().parameter_layout (
698712 instruction->parameter_number ());
699713 // Allow some paramter/result layouts to be unset in the entry
700714 // computation.
701715 if (parameter_layout.AnyLayoutIsSet ()) {
716+ // Clear out memory space in layout. Host offloader will do the
717+ // analysis later.
718+ TF_RETURN_IF_ERROR (ResetMemorySpaceInLayout (parameter_layout));
702719 // Parameter layouts must match the respective layout in
703720 // ComputationLayout, if there is one.
704721 Shape param_shape = parameter_layout.shape ();
705- // Clear out memory space in layout. Host offloader will do the
706- // analysis later.
707- TF_RETURN_IF_ERROR (ShapeUtil::ForEachMutableSubshapeWithStatus (
708- ¶m_shape, [](Shape* subshape, const ShapeIndex& index) {
709- if (!subshape->has_layout () || !subshape->IsArray ()) {
710- return absl::OkStatus ();
711- }
712- subshape->mutable_layout ()->set_memory_space (
713- Layout::kDefaultMemorySpace );
714- return absl::OkStatus ();
715- }));
716-
717722 TF_RETURN_IF_ERROR (SetInstructionLayout (param_shape, instruction));
718723 if (reverse_computation_order_) {
719724 TF_RETURN_IF_ERROR (PropagateParameterLayoutToUsers (
@@ -2033,16 +2038,7 @@ absl::Status LayoutAssignment::PropagateResultConstraint(
20332038 // Clear out memory space in layout for entry computation root. Host offloader
20342039 // will do the analysis later and add back the memory space for host outputs.
20352040 if (constraints->computation ()->IsEntryComputation ()) {
2036- Shape result_shape = result_layout.shape ();
2037- TF_RETURN_IF_ERROR (ShapeUtil::ForEachMutableSubshapeWithStatus (
2038- &result_shape, [](Shape* subshape, const ShapeIndex& shape_index) {
2039- if (subshape->has_layout () && subshape->IsArray ()) {
2040- subshape->mutable_layout ()->set_memory_space (
2041- Layout::kDefaultMemorySpace );
2042- }
2043- return absl::OkStatus ();
2044- }));
2045- TF_RETURN_IF_ERROR (result_layout.CopyLayoutFromShape (result_shape));
2041+ TF_RETURN_IF_ERROR (ResetMemorySpaceInLayout (result_layout));
20462042 }
20472043
20482044 // Propagate the use constraint of the root instruction up to the logical
@@ -2232,25 +2228,29 @@ absl::Status LayoutAssignment::AssignLayouts(LayoutConstraints& constraints) {
22322228 // layout constraint.
22332229 if (constraints.ResultLayout () != nullptr &&
22342230 constraints.ResultLayout ()->LayoutIsSet ()) {
2231+ ShapeLayout result_layout = *constraints.ResultLayout ();
2232+ // Clear out memory space in layout. Host offloader will do the
2233+ // analysis later.
2234+ TF_RETURN_IF_ERROR (ResetMemorySpaceInLayout (result_layout));
22352235 // Layout assignment at this point only does minor-to-major assignment so
22362236 // tiling info should be ignored here for comparison.
22372237 VLOG (5 ) << " Computation result layout needs root copying\n " ;
2238- if (!constraints. ResultLayout ()-> MatchesLayoutInShape (
2238+ if (!result_layout. MatchesLayoutInShape (
22392239 computation->root_instruction ()->shape (),
22402240 /* minor_to_major_only=*/ true )) {
22412241 TF_ASSIGN_OR_RETURN (
22422242 HloInstruction * new_root,
2243- CreateCopyWithNewLayout (constraints. ResultLayout ()-> shape (),
2243+ CreateCopyWithNewLayout (result_layout. shape (),
22442244 computation->root_instruction ()));
22452245 computation->set_root_instruction (new_root);
22462246 } else {
22472247 // Copy the tiling info/tail_padding_alignment_in_elements specified in
22482248 // result layout.
2249- auto copy_tiling = [&constraints ](xla::Shape* subshape,
2250- const xla::ShapeIndex& index) {
2249+ auto copy_tiling = [&result_layout ](xla::Shape* subshape,
2250+ const xla::ShapeIndex& index) {
22512251 if (subshape->IsArray ()) {
2252- const Shape& result_shape = ShapeUtil::GetSubshape (
2253- constraints. ResultLayout ()-> shape (), index);
2252+ const Shape& result_shape =
2253+ ShapeUtil::GetSubshape (result_layout. shape (), index);
22542254 if (result_shape.layout ().tiles_size () != 0 ) {
22552255 subshape->mutable_layout ()->mutable_tiles ()->assign (
22562256 result_shape.layout ().tiles ().begin (),
0 commit comments