@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313limitations under the License. */
1414
1515#include < set>
16+ #include " paddle/phi/infermeta/spmd_rules/bmm.h"
1617#include " test/cpp/auto_parallel/spmd_rule_test_util.h"
1718
1819namespace paddle {
@@ -411,6 +412,94 @@ TEST(MatmulGradInferSpmd, Ctor) {
411412 }
412413}
413414
415+ TEST (BmmInferSpmd, CoShard) {
416+ std::vector<int64_t > mesh_shape = {2 , 2 , 2 };
417+ std::vector<int64_t > process_ids = {0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 };
418+ std::vector<std::string> dim_names = {" x" , " y" , " z" };
419+ ProcessMesh process_mesh (mesh_shape, process_ids, dim_names);
420+
421+ std::vector<int64_t > x_shape = {4 , 16 , 8 };
422+ std::vector<std::vector<int64_t >> x_dims_mapping = {{0 , 1 }, {2 }, {}};
423+ TensorDistAttr x_dist_attr;
424+ x_dist_attr.set_process_mesh (process_mesh);
425+ x_dist_attr.set_dims_mapping (x_dims_mapping);
426+ x_dist_attr.set_dynamic_dims (std::vector<bool >(x_shape.size (), false ));
427+ phi::distributed::DistMetaTensor x (common::make_ddim (x_shape), x_dist_attr);
428+
429+ std::vector<int64_t > y_shape = {4 , 8 , 32 };
430+ std::vector<std::vector<int64_t >> y_dims_mapping = {{0 , 1 }, {}, {}};
431+ TensorDistAttr y_dist_attr;
432+ y_dist_attr.set_process_mesh (process_mesh);
433+ y_dist_attr.set_dims_mapping (y_dims_mapping);
434+ y_dist_attr.set_dynamic_dims (std::vector<bool >(y_shape.size (), false ));
435+ phi::distributed::DistMetaTensor y (common::make_ddim (y_shape), y_dist_attr);
436+
437+ auto bmm_spmd_info = phi::distributed::BmmInferSpmd (x, y);
438+
439+ ASSERT_EQ (bmm_spmd_info.first .size (), static_cast <size_t >(2 ));
440+ ASSERT_EQ (bmm_spmd_info.second .size (), static_cast <size_t >(1 ));
441+
442+ check_multi_dims_mapping (bmm_spmd_info.first [0 ], x_dims_mapping);
443+ EXPECT_FALSE (is_partial (bmm_spmd_info.first [0 ]));
444+ check_multi_dims_mapping (bmm_spmd_info.first [1 ], y_dims_mapping);
445+ EXPECT_FALSE (is_partial (bmm_spmd_info.first [1 ]));
446+
447+ const std::vector<std::vector<int64_t >> expected_out_dims_mapping = {
448+ {0 , 1 }, {2 }, {}};
449+ check_multi_dims_mapping (bmm_spmd_info.second [0 ], expected_out_dims_mapping);
450+ EXPECT_FALSE (is_partial (bmm_spmd_info.second [0 ]));
451+ }
452+
453+ TEST (BmmGradInferSpmd, CoShard) {
454+ std::vector<int64_t > mesh_shape = {2 , 2 , 2 };
455+ std::vector<int64_t > process_ids = {0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 };
456+ std::vector<std::string> dim_names = {" x" , " y" , " z" };
457+ ProcessMesh process_mesh (mesh_shape, process_ids, dim_names);
458+
459+ std::vector<int64_t > x_shape = {4 , 16 , 8 };
460+ std::vector<std::vector<int64_t >> x_dims_mapping = {{0 , 1 }, {2 }, {}};
461+ TensorDistAttr x_dist_attr;
462+ x_dist_attr.set_process_mesh (process_mesh);
463+ x_dist_attr.set_dims_mapping (x_dims_mapping);
464+ x_dist_attr.set_dynamic_dims (std::vector<bool >(x_shape.size (), false ));
465+ phi::distributed::DistMetaTensor x (common::make_ddim (x_shape), x_dist_attr);
466+
467+ std::vector<int64_t > y_shape = {4 , 8 , 32 };
468+ std::vector<std::vector<int64_t >> y_dims_mapping = {{0 , 1 }, {}, {}};
469+ TensorDistAttr y_dist_attr;
470+ y_dist_attr.set_process_mesh (process_mesh);
471+ y_dist_attr.set_dims_mapping (y_dims_mapping);
472+ y_dist_attr.set_dynamic_dims (std::vector<bool >(y_shape.size (), false ));
473+ phi::distributed::DistMetaTensor y (common::make_ddim (y_shape), y_dist_attr);
474+
475+ std::vector<int64_t > out_grad_shape = {4 , 16 , 32 };
476+ std::vector<std::vector<int64_t >> out_grad_dims_mapping = {{0 , 1 }, {2 }, {}};
477+ TensorDistAttr out_grad_dist_attr;
478+ out_grad_dist_attr.set_process_mesh (process_mesh);
479+ out_grad_dist_attr.set_dims_mapping (out_grad_dims_mapping);
480+ out_grad_dist_attr.set_dynamic_dims (
481+ std::vector<bool >(out_grad_shape.size (), false ));
482+ phi::distributed::DistMetaTensor out_grad (common::make_ddim (out_grad_shape),
483+ out_grad_dist_attr);
484+
485+ auto bmm_grad_spmd_info = phi::distributed::BmmGradInferSpmd (x, y, out_grad);
486+
487+ ASSERT_EQ (bmm_grad_spmd_info.first .size (), static_cast <size_t >(3 ));
488+ ASSERT_EQ (bmm_grad_spmd_info.second .size (), static_cast <size_t >(2 ));
489+
490+ check_multi_dims_mapping (bmm_grad_spmd_info.first [0 ], x_dims_mapping);
491+ EXPECT_FALSE (is_partial (bmm_grad_spmd_info.first [0 ]));
492+ check_multi_dims_mapping (bmm_grad_spmd_info.first [1 ], y_dims_mapping);
493+ EXPECT_FALSE (is_partial (bmm_grad_spmd_info.first [1 ]));
494+ check_multi_dims_mapping (bmm_grad_spmd_info.first [2 ], out_grad_dims_mapping);
495+ EXPECT_FALSE (is_partial (bmm_grad_spmd_info.first [2 ]));
496+
497+ check_multi_dims_mapping (bmm_grad_spmd_info.second [0 ], x_dims_mapping);
498+ EXPECT_FALSE (is_partial (bmm_grad_spmd_info.second [0 ]));
499+ check_multi_dims_mapping (bmm_grad_spmd_info.second [1 ], y_dims_mapping);
500+ EXPECT_TRUE (is_partial (bmm_grad_spmd_info.second [1 ]));
501+ check_partial_dims (bmm_grad_spmd_info.second [1 ], {2 });
502+ }
414503} // namespace auto_parallel
415504} // namespace distributed
416505} // namespace paddle
0 commit comments