Skip to content

Commit 09e1fa5

Browse files
authored
[Auto Parallel] Add spmd_rule about sharding on the same tensor dim by many mesh dim for reshape (#74352)
* Add spmd_rule about sharding on the same tensor dim by many mesh dim for reshape * Fix order * refine code && tests case * fix typos * fix typos * Fix paddle::get usage
1 parent c096ce1 commit 09e1fa5

File tree

9 files changed

+476
-21
lines changed

9 files changed

+476
-21
lines changed

paddle/phi/core/distributed/auto_parallel/dist_attr.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,10 @@ void TensorDistAttr::set_default_dynamic_dims(
176176
dynamic_dims_ = std::vector<bool>(tensor_shape.size(), false);
177177
}
178178

179+
// Reset dynamic_dims_ to `tensor_shape_size` entries, all marked static
// (false). Overload used when only the rank, not the full shape, is known.
void TensorDistAttr::set_default_dynamic_dims(int64_t tensor_shape_size) {
  dynamic_dims_.assign(tensor_shape_size, false);
}
182+
179183
void TensorDistAttr::mark_annotated(const std::string& name) {
180184
auto result = std::find(std::begin(fields_), std::end(fields_), name);
181185
if (result != std::end(fields_)) {

paddle/phi/core/distributed/auto_parallel/dist_attr.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,8 @@ class TEST_API TensorDistAttr {
156156

157157
void set_default_dynamic_dims(const std::vector<int64_t>& tensor_shape);
158158

159+
void set_default_dynamic_dims(int64_t tensor_shape_size);
160+
159161
const std::map<std::string, bool>& annotated() const { return annotated_; }
160162

161163
void set_annotated(const std::map<std::string, bool>& annotated);

paddle/phi/infermeta/spmd_rules/dim_trans.cc

Lines changed: 255 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ std::shared_ptr<DimTrans> make_split(const std::shared_ptr<DimTrans> dim,
161161
// map between from idx in shape to new_shape
162162
std::vector<int64_t> idx_map(shape.size(), -1);
163163
for (int i = 0, n = static_cast<int>(shape.size()); i < n; ++i) {
164-
if (shape[id] != 1) {
164+
if (shape[i] != 1) {
165165
idx_map[i] = static_cast<int64_t>(new_shape.size());
166166
new_shape.emplace_back(shape[i]);
167167
}
@@ -272,6 +272,139 @@ std::vector<std::shared_ptr<DimTrans>> GetDimTrans(
272272
return ret_dim_trans;
273273
}
274274

275+
// Co-shard variant of GetDimTrans: walks one output-dim transformation and
// decides which input dims may keep their (possibly multi-mesh-dim) sharding.
// Returns the InputDim nodes whose sharding survives; mutates `shardable`
// (per input-dim / per mesh-dim feasibility matrix) and `seen_dims` (input
// dims referenced by this transformation).
// NOTE(review): `input_dims_mapping` maps each input dim to the list of mesh
// dims it is sharded on (co-sharding) — presumably -1-free for sharded dims;
// verify against callers.
std::vector<std::shared_ptr<DimTrans>> GetDimTransCoShard(
    const std::shared_ptr<DimTrans> dim_trans,
    const std::vector<int64_t>& input_shape,
    const std::vector<int64_t>& mesh_shape,
    const std::vector<std::vector<int64_t>>& input_dims_mapping,
    const std::set<int64_t>& sharded_input_dims,
    std::vector<std::vector<bool>>* shardable,
    std::set<int64_t>* seen_dims) {
  DimTrans::Type type = dim_trans->type();
  std::vector<std::shared_ptr<DimTrans>> ret_dim_trans;

  if (type == DimTrans::Type::INPUTDIM) {
    // Leaf case: the output dim is exactly one input dim; keep it only if
    // that input dim is sharded.
    std::shared_ptr<InputDim> inputdim =
        std::dynamic_pointer_cast<InputDim>(dim_trans);
    int64_t dim = inputdim->input_dim();
    seen_dims->insert(dim);

    if (sharded_input_dims.count(dim) > 0) {
      ret_dim_trans.push_back(dim_trans);
    }
  } else if (type == DimTrans::Type::FLATTEN) {
    std::shared_ptr<Flatten> flatten =
        std::dynamic_pointer_cast<Flatten>(dim_trans);
    const std::vector<std::shared_ptr<DimTrans>>& inputs = flatten->inputs();

    int64_t nmesh = (*shardable)[0].size();  // NOLINT
    // Running product of mesh sizes over all mesh dims seen so far; the
    // flattened dim stays shardable only while the leading sharded extent
    // divides evenly by this product.
    int64_t mesh_shape_prod = 1;

    int last_shard_idx = -1;
    int64_t first_shard_idx = -1;
    int64_t first_sharded_shape = -1;

    // Scan a prefix of consecutive sharded INPUTDIM inputs; any break
    // condition (non-INPUTDIM, unsharded, or non-divisible) stops the prefix.
    for (int i = 0, n = static_cast<int>(inputs.size()); i < n; ++i) {
      std::shared_ptr<DimTrans> input = inputs[i];
      if (input->type() != DimTrans::Type::INPUTDIM) {
        break;
      }
      std::shared_ptr<InputDim> inputdim =
          std::dynamic_pointer_cast<InputDim>(input);
      if (sharded_input_dims.count(inputdim->input_dim()) > 0) {
        if (first_shard_idx == -1) {
          first_shard_idx = i;
          first_sharded_shape = input_shape[inputdim->input_dim()];
        }
        // Accumulate every mesh dim this input dim is co-sharded on.
        for (const auto& dim : input_dims_mapping[inputdim->input_dim()]) {
          mesh_shape_prod *= mesh_shape[dim];
        }
        // NOTE(review): divisibility is always checked against the FIRST
        // sharded extent, not dim i's own extent — looks intentional for the
        // co-shard semantics (all shards land on the leading dim); confirm.
        if (first_sharded_shape % mesh_shape_prod == 0) {
          ret_dim_trans.push_back(inputdim);
        } else {
          break;
        }
      } else {
        break;
      }
      last_shard_idx = i;
    }

    // Everything after the accepted prefix loses its sharding: mark those
    // input dims unshardable on every mesh dim and recurse to record usage.
    for (int i = last_shard_idx + 1, n = static_cast<int>(inputs.size()); i < n;
         i++) {
      std::shared_ptr<DimTrans> input = inputs[i];
      if (input->type() == DimTrans::Type::INPUTDIM) {
        std::shared_ptr<InputDim> inputdim =
            std::dynamic_pointer_cast<InputDim>(input);
        (*shardable)[inputdim->input_dim()].assign(nmesh, false);
      }

      GetDimTransCoShard(input,
                         input_shape,
                         mesh_shape,
                         input_dims_mapping,
                         sharded_input_dims,
                         shardable,
                         seen_dims);
    }
  } else if (type == DimTrans::Type::SPLIT) {
    std::shared_ptr<Split> split = std::dynamic_pointer_cast<Split>(dim_trans);
    std::vector<std::shared_ptr<DimTrans>> dims =
        GetDimTransCoShard(split->input(),
                           input_shape,
                           mesh_shape,
                           input_dims_mapping,
                           sharded_input_dims,
                           shardable,
                           seen_dims);
    // Local extent of this split piece; sharding is only retained on the
    // first piece (split_id == 0) of a split.
    int64_t ret_size = split->local_split_shape_value();

    if (split->split_id() == 0) {
      int64_t mesh_shape_prod = 1;
      int64_t first_shard_idx = -1;
      int64_t first_sharded_shape = -1;
      for (const auto& dim : dims) {
        PADDLE_ENFORCE_EQ(dim->type(),
                          DimTrans::Type::INPUTDIM,
                          common::errors::InvalidArgument(
                              "The returned dim_trans must be INPUTDIM."));
        std::shared_ptr<InputDim> inputdim =
            std::dynamic_pointer_cast<InputDim>(dim);
        int64_t nmesh = static_cast<int64_t>(mesh_shape.size());
        int64_t input_axis = inputdim->input_dim();

        // Check whether the sharded dim can be sharded on
        // each mesh dimension. The dimension should be
        // divisible by the mesh size that it is sharded on
        for (int64_t imesh = 0; imesh < nmesh; imesh++) {
          (*shardable)[input_axis][imesh] = (ret_size % mesh_shape[imesh] == 0);
        }

        if (first_shard_idx == -1) {
          first_shard_idx = input_axis;
          first_sharded_shape = input_shape[input_axis];
        }

        if (sharded_input_dims.count(input_axis) > 0) {
          // NOTE: inner `dim` (a mesh-dim index) shadows the outer `dim`
          // (a DimTrans node) — both reads below are the inner one.
          for (const auto& dim : input_dims_mapping[input_axis]) {
            mesh_shape_prod *= mesh_shape[dim];
          }
          // Keep the sharding only if both the split piece and the first
          // sharded extent divide by the accumulated mesh product.
          if ((ret_size % mesh_shape_prod == 0) &&
              (first_sharded_shape % mesh_shape_prod == 0)) {
            ret_dim_trans.push_back(dim);
          } else {
            break;
          }
        } else {
          break;
        }
      }
    }
  } else if (type == DimTrans::Type::SINGLETON) {
    // Singleton output dims carry no sharding; nothing to collect.
  }
  return ret_dim_trans;
}
407+
275408
void GetUsedInputDim(const std::shared_ptr<DimTrans> dim_trans,
276409
std::set<int64_t>* seen_dims) {
277410
if (dim_trans->type() == DimTrans::Type::INPUTDIM) {
@@ -311,6 +444,27 @@ InferFromDimTrans(const DistMetaTensor& input_spec,
311444
return InferFromDimTrans(input_spec, input_shape, dim_trans);
312445
}
313446

447+
std::tuple<std::vector<std::vector<int64_t>>, std::vector<std::vector<int64_t>>>
448+
InferFromDimTransCoShard(
449+
const DistMetaTensor& input_spec,
450+
const std::vector<std::shared_ptr<DimTrans>>& dim_trans) {
451+
auto input_shape = phi::vectorize(input_spec.dims());
452+
// deal with reshape xshape in dynamic
453+
if (input_shape[0] == 0 &&
454+
input_shape.size() !=
455+
input_spec.dist_attr().multi_dims_mapping().size()) {
456+
input_shape.erase(input_shape.begin());
457+
}
458+
PADDLE_ENFORCE_EQ(input_shape.size(),
459+
input_spec.dist_attr().multi_dims_mapping().size(),
460+
common::errors::InvalidArgument(
461+
"The Tensor X's rank [%d] and X's "
462+
"dims_mapping size [%d] are not matched.",
463+
input_shape.size(),
464+
input_spec.dist_attr().multi_dims_mapping().size()));
465+
return InferFromDimTransCoShard(input_spec, input_shape, dim_trans);
466+
}
467+
314468
std::tuple<std::vector<std::vector<int64_t>>, std::vector<std::vector<int64_t>>>
315469
InferFromDimTrans(const DistMetaTensor& input,
316470
const std::vector<int64_t>& input_shape,
@@ -400,4 +554,104 @@ InferFromDimTrans(const DistMetaTensor& input,
400554
return {new_input_dims_mapping, out_dims_mapping};
401555
}
402556

557+
// Core co-shard inference for reshape-like ops: given the per-output-dim
// transformations in `dim_trans`, computes which shardings of `input` can be
// preserved and maps them onto the output dims.
// Returns {new_input_dims_mapping, out_dims_mapping}; an empty vector for a
// dim means replicated.
std::tuple<std::vector<std::vector<int64_t>>, std::vector<std::vector<int64_t>>>
InferFromDimTransCoShard(
    const DistMetaTensor& input,
    const std::vector<int64_t>& input_shape,
    const std::vector<std::shared_ptr<DimTrans>>& dim_trans) {
  const std::vector<std::vector<int64_t>>& input_dims_mapping =
      input.dist_attr().multi_dims_mapping();
  const ProcessMesh& mesh = input.dist_attr().process_mesh();
  const std::vector<int64_t>& mesh_shape = mesh.shape();

  // Input dims that are sharded on at least one mesh dim.
  std::set<int64_t> sharded_input_dims;
  for (int64_t i = 0, n = static_cast<int64_t>(input_dims_mapping.size());
       i < n;
       ++i) {
    if (std::any_of(input_dims_mapping[i].begin(),
                    input_dims_mapping[i].end(),
                    [](int64_t dim) { return dim > -1; })) {
      sharded_input_dims.insert(i);
    }
  }
  int64_t ndim = static_cast<int64_t>(input_shape.size());
  int64_t nmesh = static_cast<int64_t>(mesh_shape.size());
  // shardable[i][j] == true iff input dim i may stay sharded on mesh dim j;
  // refined in place by GetDimTransCoShard below.
  std::vector<std::vector<bool>> shardable(ndim,
                                           std::vector<bool>(nmesh, true));

  std::set<int64_t> seen_input_dims;
  for (const std::shared_ptr<DimTrans>& trans : dim_trans) {
    GetUsedInputDim(trans, &seen_input_dims);
  }

  // Input dims not referenced by any output transformation cannot keep a
  // sharding (they vanish in the reshape).
  for (int64_t idim = 0; idim < ndim; idim++) {
    bool seen = seen_input_dims.count(idim);
    if (!seen) {
      shardable[idim].assign(nmesh, seen);
    }
  }

  // get the map from sharded input dimensions to output dimensions.
  // key is src dim, value is dst dim.
  std::vector<int64_t> dim_map_src2tgt(ndim, -1);
  std::unordered_map<int, std::vector<int>> dim_map_dst2src;
  for (int64_t i = 0, n = static_cast<int64_t>(dim_trans.size()); i < n; i++) {
    std::vector<std::shared_ptr<DimTrans>> dims =
        GetDimTransCoShard(dim_trans[i],
                           input_shape,
                           mesh_shape,
                           input_dims_mapping,
                           sharded_input_dims,
                           &shardable,
                           &seen_input_dims);
    for (auto dim : dims) {
      if (dim->type() == DimTrans::Type::INPUTDIM) {
        std::shared_ptr<InputDim> inputdim =
            std::dynamic_pointer_cast<InputDim>(dim);
        // NOTE: several src dims may map to the same dst dim (flatten), so
        // dst2src holds a list while src2tgt is one-to-one.
        dim_map_src2tgt[inputdim->input_dim()] = i;
        dim_map_dst2src[i].push_back(inputdim->input_dim());
      }
    }
  }

  std::vector<std::vector<int64_t>> out_dims_mapping(dim_trans.size());
  std::vector<std::vector<int64_t>> new_input_dims_mapping(
      input_dims_mapping.size());

  // set output dims mapping with corresponding input dimensions.
  // if one input dimension is sharded on a unshardable mesh after
  // splitting, we need to make it replicated.
  for (int64_t i = 0; i < ndim; i++) {
    const auto& mesh_dims = input_dims_mapping[i];
    // Skip dims with any -1 in their mapping or without a dst dim.
    // NOTE(review): this requires EVERY entry of the mapping to be >= 0 —
    // presumably unsharded dims are represented as {-1}; confirm that a
    // partially -1 mapping cannot occur here.
    if (!std::all_of(mesh_dims.begin(),
                     mesh_dims.end(),
                     [](int64_t dim) { return dim >= 0; }) ||
        dim_map_src2tgt[i] == -1) {
      continue;
    }

    bool is_unshardable = false;
    for (const auto& mesh_dim : mesh_dims) {
      if (mesh_dim >= 0 && !shardable[i][mesh_dim]) {
        is_unshardable = true;
        break;
      }
    }
    if (!is_unshardable) {
      int dst_dim = dim_map_src2tgt[i];
      // All mesh dims of a flattened group are re-attached to the smallest
      // source dim, so the whole group's sharding lands on one input dim.
      const auto& src_dims = dim_map_dst2src[dst_dim];
      auto min_dim_it = std::min_element(src_dims.begin(), src_dims.end());
      int64_t min_dim = *min_dim_it;
      out_dims_mapping[dst_dim].insert(
          out_dims_mapping[dst_dim].end(), mesh_dims.begin(), mesh_dims.end());
      new_input_dims_mapping[min_dim].insert(
          new_input_dims_mapping[min_dim].end(),
          mesh_dims.begin(),
          mesh_dims.end());
    }
  }

  return {new_input_dims_mapping, out_dims_mapping};
}
656+
403657
} // namespace phi::distributed

paddle/phi/infermeta/spmd_rules/dim_trans.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,10 +158,21 @@ std::tuple<std::vector<std::vector<int64_t>>, std::vector<std::vector<int64_t>>>
158158
InferFromDimTrans(const DistMetaTensor& input_spec,
159159
const std::vector<std::shared_ptr<DimTrans>>& dim_trans);
160160

161+
std::tuple<std::vector<std::vector<int64_t>>, std::vector<std::vector<int64_t>>>
162+
InferFromDimTransCoShard(
163+
const DistMetaTensor& input_spec,
164+
const std::vector<std::shared_ptr<DimTrans>>& dim_trans);
165+
161166
std::tuple<std::vector<std::vector<int64_t>>, std::vector<std::vector<int64_t>>>
162167
InferFromDimTrans(const DistMetaTensor& input_spec,
163168
const std::vector<int64_t>& input_shape,
164169
const std::vector<std::shared_ptr<DimTrans>>& dim_trans);
165170

171+
std::tuple<std::vector<std::vector<int64_t>>, std::vector<std::vector<int64_t>>>
172+
InferFromDimTransCoShard(
173+
const DistMetaTensor& input_spec,
174+
const std::vector<int64_t>& input_shape,
175+
const std::vector<std::shared_ptr<DimTrans>>& dim_trans);
176+
166177
} // namespace distributed
167178
} // namespace phi

0 commit comments

Comments
 (0)