@@ -161,7 +161,7 @@ std::shared_ptr<DimTrans> make_split(const std::shared_ptr<DimTrans> dim,
161
161
// map between from idx in shape to new_shape
162
162
std::vector<int64_t > idx_map (shape.size (), -1 );
163
163
for (int i = 0 , n = static_cast <int >(shape.size ()); i < n; ++i) {
164
- if (shape[id ] != 1 ) {
164
+ if (shape[i ] != 1 ) {
165
165
idx_map[i] = static_cast <int64_t >(new_shape.size ());
166
166
new_shape.emplace_back (shape[i]);
167
167
}
@@ -272,6 +272,139 @@ std::vector<std::shared_ptr<DimTrans>> GetDimTrans(
272
272
return ret_dim_trans;
273
273
}
274
274
275
// Walk a DimTrans tree and collect the InputDims whose sharding can be
// carried through this transformation ("co-shard": several input dims may
// jointly shard one output dim). Side effects:
//   - *seen_dims accumulates every input dim referenced by the tree;
//   - *shardable[input_dim][mesh_dim] is cleared to false for positions that
//     must become replicated after the transform.
// Returns the (possibly empty) list of INPUTDIM nodes that stay sharded.
// NOTE(review): the divisibility rules below assume input_shape holds the
// global shape and mesh_shape the process-mesh extents — confirm with callers.
std::vector<std::shared_ptr<DimTrans>> GetDimTransCoShard(
    const std::shared_ptr<DimTrans> dim_trans,
    const std::vector<int64_t>& input_shape,
    const std::vector<int64_t>& mesh_shape,
    const std::vector<std::vector<int64_t>>& input_dims_mapping,
    const std::set<int64_t>& sharded_input_dims,
    std::vector<std::vector<bool>>* shardable,
    std::set<int64_t>* seen_dims) {
  DimTrans::Type type = dim_trans->type();
  std::vector<std::shared_ptr<DimTrans>> ret_dim_trans;

  if (type == DimTrans::Type::INPUTDIM) {
    std::shared_ptr<InputDim> inputdim =
        std::dynamic_pointer_cast<InputDim>(dim_trans);
    int64_t dim = inputdim->input_dim();
    // Record the dim as used regardless of whether it is sharded.
    seen_dims->insert(dim);

    // Only sharded input dims propagate to the caller.
    if (sharded_input_dims.count(dim) > 0) {
      ret_dim_trans.push_back(dim_trans);
    }
  } else if (type == DimTrans::Type::FLATTEN) {
    std::shared_ptr<Flatten> flatten =
        std::dynamic_pointer_cast<Flatten>(dim_trans);
    const std::vector<std::shared_ptr<DimTrans>>& inputs = flatten->inputs();

    int64_t nmesh = (*shardable)[0].size();  // NOLINT
    // Running product of the mesh extents used by the accepted leading dims.
    int64_t mesh_shape_prod = 1;

    // Index of the last leading input accepted by the loop below; inputs
    // after it are forced replicated.
    int last_shard_idx = -1;
    int64_t first_shard_idx = -1;
    int64_t first_sharded_shape = -1;

    // Accept a maximal prefix of consecutive sharded INPUTDIMs whose first
    // sharded extent stays divisible by the accumulated mesh product; any
    // non-INPUTDIM, unsharded, or non-divisible input stops the scan.
    for (int i = 0, n = static_cast<int>(inputs.size()); i < n; ++i) {
      std::shared_ptr<DimTrans> input = inputs[i];
      if (input->type() != DimTrans::Type::INPUTDIM) {
        break;
      }
      std::shared_ptr<InputDim> inputdim =
          std::dynamic_pointer_cast<InputDim>(input);
      if (sharded_input_dims.count(inputdim->input_dim()) > 0) {
        if (first_shard_idx == -1) {
          first_shard_idx = i;
          first_sharded_shape = input_shape[inputdim->input_dim()];
        }
        for (const auto& dim : input_dims_mapping[inputdim->input_dim()]) {
          mesh_shape_prod *= mesh_shape[dim];
        }
        if (first_sharded_shape % mesh_shape_prod == 0) {
          ret_dim_trans.push_back(inputdim);
        } else {
          break;
        }
      } else {
        break;
      }
      // Only reached when the input was accepted (break paths skip this).
      last_shard_idx = i;
    }

    // Everything after the accepted prefix cannot remain sharded: mark its
    // INPUTDIMs unshardable on all mesh dims, then recurse so nested
    // transforms are still scanned and seen_dims stays complete.
    for (int i = last_shard_idx + 1, n = static_cast<int>(inputs.size()); i < n;
         i++) {
      std::shared_ptr<DimTrans> input = inputs[i];
      if (input->type() == DimTrans::Type::INPUTDIM) {
        std::shared_ptr<InputDim> inputdim =
            std::dynamic_pointer_cast<InputDim>(input);
        (*shardable)[inputdim->input_dim()].assign(nmesh, false);
      }

      GetDimTransCoShard(input,
                         input_shape,
                         mesh_shape,
                         input_dims_mapping,
                         sharded_input_dims,
                         shardable,
                         seen_dims);
    }
  } else if (type == DimTrans::Type::SPLIT) {
    std::shared_ptr<Split> split = std::dynamic_pointer_cast<Split>(dim_trans);
    // First resolve the split's source; `dims` holds the INPUTDIMs that are
    // still sharded in the source expression.
    std::vector<std::shared_ptr<DimTrans>> dims =
        GetDimTransCoShard(split->input(),
                           input_shape,
                           mesh_shape,
                           input_dims_mapping,
                           sharded_input_dims,
                           shardable,
                           seen_dims);
    int64_t ret_size = split->local_split_shape_value();

    // Only the first piece of a split may keep the sharding.
    if (split->split_id() == 0) {
      int64_t mesh_shape_prod = 1;
      int64_t first_shard_idx = -1;
      int64_t first_sharded_shape = -1;
      for (const auto& dim : dims) {
        PADDLE_ENFORCE_EQ(dim->type(),
                          DimTrans::Type::INPUTDIM,
                          common::errors::InvalidArgument(
                              "The returned dim_trans must be INPUTDIM."));
        std::shared_ptr<InputDim> inputdim =
            std::dynamic_pointer_cast<InputDim>(dim);
        int64_t nmesh = static_cast<int64_t>(mesh_shape.size());
        int64_t input_axis = inputdim->input_dim();

        // Check whether the sharded dim can be sharded on
        // each mesh dimension. The dimension should be
        // divisible by the mesh size that it is sharded on
        for (int64_t imesh = 0; imesh < nmesh; imesh++) {
          (*shardable)[input_axis][imesh] = (ret_size % mesh_shape[imesh] == 0);
        }

        // NOTE(review): unlike the FLATTEN branch, this records the first
        // axis unconditionally (even if it is not in sharded_input_dims) —
        // confirm that asymmetry is intended.
        if (first_shard_idx == -1) {
          first_shard_idx = input_axis;
          first_sharded_shape = input_shape[input_axis];
        }

        if (sharded_input_dims.count(input_axis) > 0) {
          for (const auto& dim : input_dims_mapping[input_axis]) {
            mesh_shape_prod *= mesh_shape[dim];
          }
          // Keep the sharding only if both the split piece and the first
          // sharded extent are divisible by the accumulated mesh product.
          if ((ret_size % mesh_shape_prod == 0) &&
              (first_sharded_shape % mesh_shape_prod == 0)) {
            ret_dim_trans.push_back(dim);
          } else {
            break;
          }
        } else {
          break;
        }
      }
    }
  } else if (type == DimTrans::Type::SINGLETON) {
    // Singleton dims carry no sharding; nothing to collect.
  }
  return ret_dim_trans;
}
407
+
275
408
void GetUsedInputDim (const std::shared_ptr<DimTrans> dim_trans,
276
409
std::set<int64_t >* seen_dims) {
277
410
if (dim_trans->type () == DimTrans::Type::INPUTDIM) {
@@ -311,6 +444,27 @@ InferFromDimTrans(const DistMetaTensor& input_spec,
311
444
return InferFromDimTrans (input_spec, input_shape, dim_trans);
312
445
}
313
446
447
+ std::tuple<std::vector<std::vector<int64_t >>, std::vector<std::vector<int64_t >>>
448
+ InferFromDimTransCoShard (
449
+ const DistMetaTensor& input_spec,
450
+ const std::vector<std::shared_ptr<DimTrans>>& dim_trans) {
451
+ auto input_shape = phi::vectorize (input_spec.dims ());
452
+ // deal with reshape xshape in dynamic
453
+ if (input_shape[0 ] == 0 &&
454
+ input_shape.size () !=
455
+ input_spec.dist_attr ().multi_dims_mapping ().size ()) {
456
+ input_shape.erase (input_shape.begin ());
457
+ }
458
+ PADDLE_ENFORCE_EQ (input_shape.size (),
459
+ input_spec.dist_attr ().multi_dims_mapping ().size (),
460
+ common::errors::InvalidArgument (
461
+ " The Tensor X's rank [%d] and X's "
462
+ " dims_mapping size [%d] are not matched." ,
463
+ input_shape.size (),
464
+ input_spec.dist_attr ().multi_dims_mapping ().size ()));
465
+ return InferFromDimTransCoShard (input_spec, input_shape, dim_trans);
466
+ }
467
+
314
468
std::tuple<std::vector<std::vector<int64_t >>, std::vector<std::vector<int64_t >>>
315
469
InferFromDimTrans (const DistMetaTensor& input,
316
470
const std::vector<int64_t >& input_shape,
@@ -400,4 +554,104 @@ InferFromDimTrans(const DistMetaTensor& input,
400
554
return {new_input_dims_mapping, out_dims_mapping};
401
555
}
402
556
557
// Infer the input and output dims_mappings implied by a list of DimTrans
// (one entry per output dimension), allowing one output dim to be co-sharded
// by several input dims. Returns {new input dims_mapping, output
// dims_mapping}; input dims whose sharding cannot survive the transform end
// up with an empty (replicated) mapping.
std::tuple<std::vector<std::vector<int64_t>>, std::vector<std::vector<int64_t>>>
InferFromDimTransCoShard(
    const DistMetaTensor& input,
    const std::vector<int64_t>& input_shape,
    const std::vector<std::shared_ptr<DimTrans>>& dim_trans) {
  const std::vector<std::vector<int64_t>>& input_dims_mapping =
      input.dist_attr().multi_dims_mapping();
  const ProcessMesh& mesh = input.dist_attr().process_mesh();
  const std::vector<int64_t>& mesh_shape = mesh.shape();

  // Input dims mapped to at least one mesh dim (>-1) are considered sharded.
  std::set<int64_t> sharded_input_dims;
  for (int64_t i = 0, n = static_cast<int64_t>(input_dims_mapping.size());
       i < n;
       ++i) {
    if (std::any_of(input_dims_mapping[i].begin(),
                    input_dims_mapping[i].end(),
                    [](int64_t dim) { return dim > -1; })) {
      sharded_input_dims.insert(i);
    }
  }
  int64_t ndim = static_cast<int64_t>(input_shape.size());
  int64_t nmesh = static_cast<int64_t>(mesh_shape.size());
  // shardable[d][m]: may input dim d stay sharded on mesh dim m?
  // Start optimistic (all true); passes below clear entries.
  std::vector<std::vector<bool>> shardable(ndim,
                                           std::vector<bool>(nmesh, true));

  std::set<int64_t> seen_input_dims;
  for (const std::shared_ptr<DimTrans>& trans : dim_trans) {
    GetUsedInputDim(trans, &seen_input_dims);
  }

  // Input dims not referenced by any DimTrans cannot keep their sharding.
  for (int64_t idim = 0; idim < ndim; idim++) {
    bool seen = seen_input_dims.count(idim);
    if (!seen) {
      shardable[idim].assign(nmesh, seen);
    }
  }

  // get the map from sharded input dimensions to output dimensions.
  // key is src dim, value is dst dim.
  std::vector<int64_t> dim_map_src2tgt(ndim, -1);
  // Reverse map: output dim -> list of co-sharding input dims.
  std::unordered_map<int, std::vector<int>> dim_map_dst2src;
  for (int64_t i = 0, n = static_cast<int64_t>(dim_trans.size()); i < n; i++) {
    std::vector<std::shared_ptr<DimTrans>> dims =
        GetDimTransCoShard(dim_trans[i],
                           input_shape,
                           mesh_shape,
                           input_dims_mapping,
                           sharded_input_dims,
                           &shardable,
                           &seen_input_dims);
    for (auto dim : dims) {
      if (dim->type() == DimTrans::Type::INPUTDIM) {
        std::shared_ptr<InputDim> inputdim =
            std::dynamic_pointer_cast<InputDim>(dim);
        dim_map_src2tgt[inputdim->input_dim()] = i;
        dim_map_dst2src[i].push_back(inputdim->input_dim());
      }
    }
  }

  std::vector<std::vector<int64_t>> out_dims_mapping(dim_trans.size());
  std::vector<std::vector<int64_t>> new_input_dims_mapping(
      input_dims_mapping.size());

  // set output dims mapping with corresponding input dimensions.
  // if one input dimension is sharded on a unshardable mesh after
  // splitting, we need to make it replicated.
  for (int64_t i = 0; i < ndim; i++) {
    const auto& mesh_dims = input_dims_mapping[i];
    // Skip dims with any unsharded (-1) entry or with no target output dim.
    // NOTE(review): all_of on an empty mapping is true, so a dim with an
    // empty mapping proceeds only if it has a target — confirm intended.
    if (!std::all_of(mesh_dims.begin(),
                     mesh_dims.end(),
                     [](int64_t dim) { return dim >= 0; }) ||
        dim_map_src2tgt[i] == -1) {
      continue;
    }

    // Replicate if any referenced mesh dim was marked unshardable above.
    bool is_unshardable = false;
    for (const auto& mesh_dim : mesh_dims) {
      if (mesh_dim >= 0 && !shardable[i][mesh_dim]) {
        is_unshardable = true;
        break;
      }
    }
    if (!is_unshardable) {
      int dst_dim = dim_map_src2tgt[i];
      // All co-sharding source dims of this output dim; the mesh dims are
      // accumulated onto the smallest source dim's new mapping.
      // (src_dims is non-empty: dim_map_src2tgt[i] != -1 implies i was
      // pushed into dim_map_dst2src[dst_dim].)
      const auto& src_dims = dim_map_dst2src[dst_dim];
      auto min_dim_it = std::min_element(src_dims.begin(), src_dims.end());
      int64_t min_dim = *min_dim_it;
      out_dims_mapping[dst_dim].insert(
          out_dims_mapping[dst_dim].end(), mesh_dims.begin(), mesh_dims.end());
      new_input_dims_mapping[min_dim].insert(
          new_input_dims_mapping[min_dim].end(),
          mesh_dims.begin(),
          mesh_dims.end());
    }
  }

  return {new_input_dims_mapping, out_dims_mapping};
}
656
+
403
657
} // namespace phi::distributed
0 commit comments