Skip to content

Commit d808ec6

Browse files
authored
[Auto Parallel] fix bug for transpose spmd (#69862)
1 parent b4e2d5d commit d808ec6

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

paddle/phi/infermeta/spmd_rules/transpose.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,13 +87,14 @@ SpmdInfo TransposeInferSpmd(const DistMetaTensor& x,
8787
GetDimsMappingForAxes(out_axes, axis_to_dim_map);
8888

8989
auto x_dist_attr_dst = CopyTensorDistAttrForOutput(x_dist_attr_src);
90+
x_dist_attr_dst.set_partial_status(x_dist_attr_src.partial_status());
9091
x_dist_attr_dst.set_dims_mapping(x_dims_mapping);
9192

9293
// initialize output dist_attr's process_mesh, batch_dim and dynamic dims with
9394
// input dist_attr.
9495
TensorDistAttr out_dist_attr = CopyTensorDistAttrForOutput(x_dist_attr_src);
9596
out_dist_attr.set_dims_mapping(out_dims_mapping);
96-
out_dist_attr.set_partial_status(x_dist_attr_src.partial_status());
97+
out_dist_attr.set_partial_status(x_dist_attr_dst.partial_status());
9798

9899
VLOG(4) << "TransposeInferSpmd:";
99100
VLOG(4) << "Input: shape: [" << str_join(x_shape) << "] "

0 commit comments

Comments
 (0)