Commit 57d59ba

[Auto Parallel] Add co_shard spmd_rule for bmm (PaddlePaddle#75555)
1 parent 1f1b56d commit 57d59ba
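
For context, "co_shard" means a single tensor axis is sharded across more than one mesh axis at once: in the new tests below, the batch axis of a [4, 16, 8] tensor is mapped to mesh axes {0, 1} of a [2, 2, 2] mesh. A minimal standalone C++ sketch of how a co-sharded axis divides across the mesh (illustrative only, not part of this commit; LocalDim is a hypothetical helper):

#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative only -- not part of this commit. A co-sharded tensor axis is
// divided by the product of the sizes of every mesh axis it is mapped to.
int64_t LocalDim(int64_t global_dim,
                 const std::vector<int64_t>& mesh_axes,
                 const std::vector<int64_t>& mesh_shape) {
  int64_t shards = 1;
  for (int64_t axis : mesh_axes) shards *= mesh_shape[axis];
  return global_dim / shards;
}

int main() {
  const std::vector<int64_t> mesh_shape = {2, 2, 2};  // as in the tests below
  // x has global shape [4, 16, 8] and dims mapping {{0, 1}, {2}, {}}.
  std::cout << LocalDim(4, {0, 1}, mesh_shape) << "\n";  // batch: 4 / (2*2) = 1
  std::cout << LocalDim(16, {2}, mesh_shape) << "\n";    // m: 16 / 2 = 8
  std::cout << LocalDim(8, {}, mesh_shape) << "\n";      // k: replicated -> 8
  return 0;
}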

File tree

7 files changed: +257 -1 lines
paddle/phi/infermeta/spmd_rules/bmm.cc

Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
/* Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/phi/infermeta/spmd_rules/bmm.h"

#include "glog/logging.h"

#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
#include "paddle/phi/infermeta/spmd_rules/matmul.h"
#include "paddle/phi/infermeta/spmd_rules/utils.h"

namespace phi {
namespace distributed {

namespace {

std::vector<int64_t> CheckBmmTensorMeta(const DistMetaTensor& tensor,
                                        const char* tensor_name,
                                        const char* rule_name) {
  const auto shape = common::vectorize(tensor.dims());
  const auto& dims_mapping = tensor.dist_attr().multi_dims_mapping();

  PADDLE_ENFORCE_EQ(shape.size(),
                    3,
                    common::errors::InvalidArgument(
                        "%s expects %s to be a 3-D tensor, but it has rank %d.",
                        rule_name,
                        tensor_name,
                        static_cast<int>(shape.size())));
  PADDLE_ENFORCE_EQ(
      dims_mapping.size(),
      shape.size(),
      common::errors::InvalidArgument(
          "%s expects dims_mapping length of %s (%d) to match its rank (%d).",
          rule_name,
          tensor_name,
          static_cast<int>(dims_mapping.size()),
          static_cast<int>(shape.size())));

  return shape;
}

inline void CheckDimEqual(int64_t lhs,
                          int64_t rhs,
                          const char* lhs_desc,
                          const char* rhs_desc,
                          const char* rule_name) {
  if (lhs != -1 && rhs != -1) {
    PADDLE_ENFORCE_EQ(lhs,
                      rhs,
                      common::errors::InvalidArgument(
                          "%s expects %s (%d) to be equal to %s (%d).",
                          rule_name,
                          lhs_desc,
                          lhs,
                          rhs_desc,
                          rhs));
  }
}

}  // namespace

SpmdInfo BmmInferSpmd(const DistMetaTensor& x, const DistMetaTensor& y) {
  const auto x_shape = CheckBmmTensorMeta(x, "Input(X)", "BmmInferSpmd");
  const auto y_shape = CheckBmmTensorMeta(y, "Input(Y)", "BmmInferSpmd");

  CheckDimEqual(x_shape[2],
                y_shape[1],
                "the last dimension of Input(X)",
                "the second dimension of Input(Y)",
                "BmmInferSpmd");
  CheckDimEqual(x_shape[0],
                y_shape[0],
                "the batch dimension of Input(X)",
                "the batch dimension of Input(Y)",
                "BmmInferSpmd");

  VLOG(6) << "BmmInferSpmd delegates to MatmulInferSpmd (trans_x=false, "
             "trans_y=false).";

  return MatmulInferSpmd(x, y, false, false);
}

SpmdInfo BmmGradInferSpmd(const DistMetaTensor& x,
                          const DistMetaTensor& y,
                          const DistMetaTensor& out_grad) {
  const auto x_shape = CheckBmmTensorMeta(x, "Input(X)", "BmmGradInferSpmd");
  const auto y_shape = CheckBmmTensorMeta(y, "Input(Y)", "BmmGradInferSpmd");
  const auto out_grad_shape =
      CheckBmmTensorMeta(out_grad, "Output@Grad", "BmmGradInferSpmd");

  CheckDimEqual(x_shape[2],
                y_shape[1],
                "the last dimension of Input(X)",
                "the second dimension of Input(Y)",
                "BmmGradInferSpmd");
  CheckDimEqual(x_shape[0],
                y_shape[0],
                "the batch dimension of Input(X)",
                "the batch dimension of Input(Y)",
                "BmmGradInferSpmd");
  CheckDimEqual(x_shape[0],
                out_grad_shape[0],
                "the batch dimension of Input(X)",
                "the batch dimension of Output@Grad",
                "BmmGradInferSpmd");
  CheckDimEqual(x_shape[1],
                out_grad_shape[1],
                "the second dimension of Input(X)",
                "the second dimension of Output@Grad",
                "BmmGradInferSpmd");
  CheckDimEqual(y_shape[2],
                out_grad_shape[2],
                "the last dimension of Input(Y)",
                "the last dimension of Output@Grad",
                "BmmGradInferSpmd");

  VLOG(6)
      << "BmmGradInferSpmd delegates to MatmulGradInferSpmd (trans_x=false, "
         "trans_y=false).";

  return MatmulGradInferSpmd(x, y, out_grad, false, false);
}
}  // namespace distributed
}  // namespace phi
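
The rule is deliberately thin: bmm is matmul with rank-3 operands and no transposes, so after validating ranks and dimension compatibility it forwards to the existing matmul rule. A toy, self-contained C++ sketch of the axis-inheritance idea behind the einsum "bmk,bkn->bmn" (illustrative only, not Paddle's matmul implementation):

#include <cstdint>
#include <iostream>
#include <vector>

// Toy sketch, not Paddle code: each tensor axis carries a list of mesh axes
// (empty = replicated). For "bmk,bkn->bmn", each output axis inherits the
// mesh axes of the input axis it comes from; a sharded contracted axis k
// would instead leave the output partial along those mesh axes.
using AxisMapping = std::vector<int64_t>;

int main() {
  // Mirrors the new unit test: x = [b, m, k], y = [b, k, n] on a 2x2x2 mesh.
  std::vector<AxisMapping> x = {{0, 1}, {2}, {}};  // b co-sharded on {0, 1}
  std::vector<AxisMapping> y = {{0, 1}, {}, {}};

  // out = [b, m, n]: b from either input, m from x, n from y.
  std::vector<AxisMapping> out = {x[0], x[1], y[2]};

  for (const auto& axis : out) {  // prints: {0,1} {2} {}
    std::cout << "{";
    for (size_t i = 0; i < axis.size(); ++i)
      std::cout << (i ? "," : "") << axis[i];
    std::cout << "} ";
  }
  std::cout << "\n";
  return 0;
}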
paddle/phi/infermeta/spmd_rules/bmm.h

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
/* Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/phi/core/distributed/auto_parallel/dist_meta_tensor.h"
#include "paddle/phi/core/distributed/type_defs.h"

namespace phi {
namespace distributed {

SpmdInfo BmmInferSpmd(const DistMetaTensor& x, const DistMetaTensor& y);

SpmdInfo BmmGradInferSpmd(const DistMetaTensor& x,
                          const DistMetaTensor& y,
                          const DistMetaTensor& out_grad);

}  // namespace distributed
}  // namespace phi

paddle/phi/infermeta/spmd_rules/rules.cc

Lines changed: 3 additions & 1 deletion
@@ -40,7 +40,9 @@ PD_REGISTER_SPMD_RULE(matmul,
 PD_REGISTER_SPMD_RULE(matmul_v2,  // static mode
                       PD_INFER_SPMD(phi::distributed::MatmulInferSpmd),
                       PD_INFER_SPMD(phi::distributed::MatmulInferSpmdReverse));
-
+PD_REGISTER_SPMD_RULE(bmm,
+                      PD_INFER_SPMD(phi::distributed::BmmInferSpmd),
+                      PD_INFER_SPMD(phi::distributed::BmmGradInferSpmd));
 PD_REGISTER_SPMD_RULE(
     elementwise_unary,
     PD_INFER_SPMD(phi::distributed::ElementwiseUnaryInferSpmd),
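
Once registered, static-mode code can look the rule up by op name. A hedged sketch, assuming the SpmdRuleFactory singleton and ContainsSpmdRule helper from Paddle's inferspmd utilities (their availability and exact signatures are an assumption here, not confirmed by this diff):

// Hedged sketch: SpmdRuleFactory::Instance() and ContainsSpmdRule are assumed
// to come from paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h.
#include "paddle/phi/core/distributed/auto_parallel/inferspmd_utils.h"

bool BmmRuleIsRegistered() {
  return phi::distributed::SpmdRuleFactory::Instance().ContainsSpmdRule("bmm");
}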

paddle/phi/infermeta/spmd_rules/rules.h

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@ limitations under the License. */
 #include "paddle/phi/infermeta/spmd_rules/argmin.h"
 #include "paddle/phi/infermeta/spmd_rules/argsort.h"
 #include "paddle/phi/infermeta/spmd_rules/batch_norm.h"
+#include "paddle/phi/infermeta/spmd_rules/bmm.h"
 #include "paddle/phi/infermeta/spmd_rules/c_embedding.h"
 #include "paddle/phi/infermeta/spmd_rules/c_softmax_with_cross_entropy.h"
 #include "paddle/phi/infermeta/spmd_rules/c_softmax_with_multi_label_cross_entropy.h"

paddle/phi/ops/yaml/backward.yaml

Lines changed: 1 addition & 0 deletions
@@ -323,6 +323,7 @@
   output : Tensor(x_grad), Tensor(y_grad)
   infer_meta :
     func : BmmGradInferMeta
+    spmd_rule : BmmGradInferSpmd
   kernel :
     func : bmm_grad
     data_type : out_grad

paddle/phi/ops/yaml/ops.yaml

Lines changed: 1 addition & 0 deletions
@@ -773,6 +773,7 @@
   output : Tensor(out)
   infer_meta :
     func : BmmInferMeta
+    spmd_rule: BmmInferSpmd
   kernel :
     func : bmm
   backward : bmm_grad
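
Together with the rules.cc registration above (which covers static mode, per the existing "// static mode" note on matmul_v2), these two spmd_rule entries attach BmmInferSpmd and BmmGradInferSpmd to the generated infer_meta paths of bmm and bmm_grad.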

test/cpp/auto_parallel/matmul_co_shard_spmd_rule_test.cc

Lines changed: 89 additions & 0 deletions
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include <set>
+#include "paddle/phi/infermeta/spmd_rules/bmm.h"
 #include "test/cpp/auto_parallel/spmd_rule_test_util.h"
 
 namespace paddle {
@@ -411,6 +412,94 @@ TEST(MatmulGradInferSpmd, Ctor) {
   }
 }
 
+TEST(BmmInferSpmd, CoShard) {
+  std::vector<int64_t> mesh_shape = {2, 2, 2};
+  std::vector<int64_t> process_ids = {0, 1, 2, 3, 4, 5, 6, 7};
+  std::vector<std::string> dim_names = {"x", "y", "z"};
+  ProcessMesh process_mesh(mesh_shape, process_ids, dim_names);
+
+  std::vector<int64_t> x_shape = {4, 16, 8};
+  std::vector<std::vector<int64_t>> x_dims_mapping = {{0, 1}, {2}, {}};
+  TensorDistAttr x_dist_attr;
+  x_dist_attr.set_process_mesh(process_mesh);
+  x_dist_attr.set_dims_mapping(x_dims_mapping);
+  x_dist_attr.set_dynamic_dims(std::vector<bool>(x_shape.size(), false));
+  phi::distributed::DistMetaTensor x(common::make_ddim(x_shape), x_dist_attr);
+
+  std::vector<int64_t> y_shape = {4, 8, 32};
+  std::vector<std::vector<int64_t>> y_dims_mapping = {{0, 1}, {}, {}};
+  TensorDistAttr y_dist_attr;
+  y_dist_attr.set_process_mesh(process_mesh);
+  y_dist_attr.set_dims_mapping(y_dims_mapping);
+  y_dist_attr.set_dynamic_dims(std::vector<bool>(y_shape.size(), false));
+  phi::distributed::DistMetaTensor y(common::make_ddim(y_shape), y_dist_attr);
+
+  auto bmm_spmd_info = phi::distributed::BmmInferSpmd(x, y);
+
+  ASSERT_EQ(bmm_spmd_info.first.size(), static_cast<size_t>(2));
+  ASSERT_EQ(bmm_spmd_info.second.size(), static_cast<size_t>(1));
+
+  check_multi_dims_mapping(bmm_spmd_info.first[0], x_dims_mapping);
+  EXPECT_FALSE(is_partial(bmm_spmd_info.first[0]));
+  check_multi_dims_mapping(bmm_spmd_info.first[1], y_dims_mapping);
+  EXPECT_FALSE(is_partial(bmm_spmd_info.first[1]));
+
+  const std::vector<std::vector<int64_t>> expected_out_dims_mapping = {
+      {0, 1}, {2}, {}};
+  check_multi_dims_mapping(bmm_spmd_info.second[0], expected_out_dims_mapping);
+  EXPECT_FALSE(is_partial(bmm_spmd_info.second[0]));
+}
+
+TEST(BmmGradInferSpmd, CoShard) {
+  std::vector<int64_t> mesh_shape = {2, 2, 2};
+  std::vector<int64_t> process_ids = {0, 1, 2, 3, 4, 5, 6, 7};
+  std::vector<std::string> dim_names = {"x", "y", "z"};
+  ProcessMesh process_mesh(mesh_shape, process_ids, dim_names);
+
+  std::vector<int64_t> x_shape = {4, 16, 8};
+  std::vector<std::vector<int64_t>> x_dims_mapping = {{0, 1}, {2}, {}};
+  TensorDistAttr x_dist_attr;
+  x_dist_attr.set_process_mesh(process_mesh);
+  x_dist_attr.set_dims_mapping(x_dims_mapping);
+  x_dist_attr.set_dynamic_dims(std::vector<bool>(x_shape.size(), false));
+  phi::distributed::DistMetaTensor x(common::make_ddim(x_shape), x_dist_attr);
+
+  std::vector<int64_t> y_shape = {4, 8, 32};
+  std::vector<std::vector<int64_t>> y_dims_mapping = {{0, 1}, {}, {}};
+  TensorDistAttr y_dist_attr;
+  y_dist_attr.set_process_mesh(process_mesh);
+  y_dist_attr.set_dims_mapping(y_dims_mapping);
+  y_dist_attr.set_dynamic_dims(std::vector<bool>(y_shape.size(), false));
+  phi::distributed::DistMetaTensor y(common::make_ddim(y_shape), y_dist_attr);
+
+  std::vector<int64_t> out_grad_shape = {4, 16, 32};
+  std::vector<std::vector<int64_t>> out_grad_dims_mapping = {{0, 1}, {2}, {}};
+  TensorDistAttr out_grad_dist_attr;
+  out_grad_dist_attr.set_process_mesh(process_mesh);
+  out_grad_dist_attr.set_dims_mapping(out_grad_dims_mapping);
+  out_grad_dist_attr.set_dynamic_dims(
+      std::vector<bool>(out_grad_shape.size(), false));
+  phi::distributed::DistMetaTensor out_grad(common::make_ddim(out_grad_shape),
+                                            out_grad_dist_attr);
+
+  auto bmm_grad_spmd_info = phi::distributed::BmmGradInferSpmd(x, y, out_grad);
+
+  ASSERT_EQ(bmm_grad_spmd_info.first.size(), static_cast<size_t>(3));
+  ASSERT_EQ(bmm_grad_spmd_info.second.size(), static_cast<size_t>(2));
+
+  check_multi_dims_mapping(bmm_grad_spmd_info.first[0], x_dims_mapping);
+  EXPECT_FALSE(is_partial(bmm_grad_spmd_info.first[0]));
+  check_multi_dims_mapping(bmm_grad_spmd_info.first[1], y_dims_mapping);
+  EXPECT_FALSE(is_partial(bmm_grad_spmd_info.first[1]));
+  check_multi_dims_mapping(bmm_grad_spmd_info.first[2], out_grad_dims_mapping);
+  EXPECT_FALSE(is_partial(bmm_grad_spmd_info.first[2]));
+
+  check_multi_dims_mapping(bmm_grad_spmd_info.second[0], x_dims_mapping);
+  EXPECT_FALSE(is_partial(bmm_grad_spmd_info.second[0]));
+  check_multi_dims_mapping(bmm_grad_spmd_info.second[1], y_dims_mapping);
+  EXPECT_TRUE(is_partial(bmm_grad_spmd_info.second[1]));
+  check_partial_dims(bmm_grad_spmd_info.second[1], {2});
+}
 }  // namespace auto_parallel
 }  // namespace distributed
 }  // namespace paddle
