
Commit 01666a6

[Auto Parallel] Replace paddle::get usage in spmd_rules dir (#74543)
* Replace paddle::get usage in spmd_rules dir
* Fix bug
1 parent f1ae790 commit 01666a6
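
The change is mechanical: every index-based variant access (paddle::get<0>, paddle::get<TensorDistAttr>) in the SPMD rules becomes a type-named PADDLE_GET_CONST or PADDLE_GET call. As a rough, self-contained illustration of the difference — using a plain std::variant and a hypothetical GET_CONST stand-in, not Paddle's actual ArgDistAttr or macro definitions — consider:

```cpp
// A minimal sketch of the pattern this commit applies. GET_CONST below is a
// hypothetical stand-in for PADDLE_GET_CONST, and ArgDistAttr is modeled as a
// plain std::variant; neither matches Paddle's real definitions.
#include <iostream>
#include <stdexcept>
#include <string>
#include <variant>
#include <vector>

struct TensorDistAttr {
  std::string to_string() const { return "TensorDistAttr"; }
};
using ArgDistAttr = std::variant<TensorDistAttr, std::vector<TensorDistAttr>>;

// Type-named, checked access: the expected type appears at the call site and
// in the failure message, instead of a bare alternative index.
#define GET_CONST(TYPE, VALUE)                                       \
  ([](const auto& v) -> const TYPE& {                                \
    const TYPE* p = std::get_if<TYPE>(&v);                           \
    if (p == nullptr) throw std::runtime_error("expected " #TYPE);   \
    return *p;                                                       \
  }(VALUE))

int main() {
  ArgDistAttr attr = TensorDistAttr{};
  // Old style: the reader must know that alternative 0 is TensorDistAttr,
  // and a mismatch surfaces as a generic bad-variant error.
  std::cout << std::get<0>(attr).to_string() << "\n";
  // New style: self-documenting, with a type-specific error on mismatch.
  std::cout << GET_CONST(TensorDistAttr, attr).to_string() << "\n";
  return 0;
}
```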

5 files changed: +30 -19 lines changed

paddle/phi/infermeta/spmd_rules/fused_dropout_add.cc

Lines changed: 6 additions & 3 deletions
@@ -36,7 +36,8 @@ SpmdInfo FusedDropoutAddSpmdBase(const DistMetaTensor& x,
   VLOG(4) << "x dist_attr: [" << x.dist_attr().to_string() << "]";
   VLOG(4) << "y dist_attr: [" << y.dist_attr().to_string() << "]";
   VLOG(4) << "out dist_attr: ["
-          << paddle::get<0>(out_info.second[0]).to_string() << "]";
+          << PADDLE_GET_CONST(TensorDistAttr, out_info.second[0]).to_string()
+          << "]";
   VLOG(4) << "seed_offset dist_attr: [" << seed_offset_dist_attr.to_string()
           << "]";
   return {{x.dist_attr(), y.dist_attr()},
@@ -51,9 +52,11 @@ SpmdInfo FusedDropoutAddSpmdReverseBase(const DistMetaTensor& x,

   VLOG(4) << "out dist_attr: [" << out.dist_attr().to_string() << "]";
   VLOG(4) << "x dist_attr: ["
-          << paddle::get<0>(reverse_info.first[0]).to_string() << "]";
+          << PADDLE_GET_CONST(TensorDistAttr, reverse_info.first[0]).to_string()
+          << "]";
   VLOG(4) << "y dist_attr: ["
-          << paddle::get<0>(reverse_info.first[1]).to_string() << "]";
+          << PADDLE_GET_CONST(TensorDistAttr, reverse_info.first[1]).to_string()
+          << "]";
   return {reverse_info.first,
           {reverse_info.second[0], seed_offset.dist_attr()}};
 }

paddle/phi/infermeta/spmd_rules/index_select.cc

Lines changed: 6 additions & 4 deletions
@@ -98,10 +98,12 @@ SpmdInfo IndexSelectGradInferSpmd(const DistMetaTensor& x,
                         out_grad_ndim));
   // now use forward spmd rule to reduce complexity without actual cost eval.
   SpmdInfo fwd_spmd_info = IndexSelectInferSpmd(x, index, axis);
-  TensorDistAttr x_dist_attr_dst = paddle::get<0>(fwd_spmd_info.first[0]);
-  TensorDistAttr index_dist_attr_dst = paddle::get<0>(fwd_spmd_info.first[1]);
-  TensorDistAttr out_grad_dist_attr_dst =
-      paddle::get<0>(fwd_spmd_info.second[0]);
+  const TensorDistAttr& x_dist_attr_dst =
+      PADDLE_GET_CONST(TensorDistAttr, fwd_spmd_info.first[0]);
+  const TensorDistAttr& index_dist_attr_dst =
+      PADDLE_GET_CONST(TensorDistAttr, fwd_spmd_info.first[1]);
+  const TensorDistAttr& out_grad_dist_attr_dst =
+      PADDLE_GET_CONST(TensorDistAttr, fwd_spmd_info.second[0]);

   TensorDistAttr x_grad_dist_attr_dst = x_dist_attr_dst;
   x_grad_dist_attr_dst.clean_partial_status();
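
Besides the macro swap, the index_select change binds const references to the extracted attributes instead of copying them by value. A small sketch of what that buys, with illustrative types only (not the real SpmdInfo layout):

```cpp
// Illustrative types only, not the real SpmdInfo layout: a const reference
// aliases the attribute stored inside the variant, while initializing a
// TensorDistAttr by value copies it.
#include <cassert>
#include <cstdint>
#include <variant>
#include <vector>

struct TensorDistAttr {
  std::vector<int64_t> dims_mapping;
};
using ArgDistAttr = std::variant<TensorDistAttr, std::vector<TensorDistAttr>>;

int main() {
  ArgDistAttr attr = TensorDistAttr{{0, -1}};
  // Copy: 'by_value' is detached from the variant.
  TensorDistAttr by_value = std::get<TensorDistAttr>(attr);
  // Const reference: no copy; later reads observe the variant's state.
  const TensorDistAttr& by_ref = std::get<TensorDistAttr>(attr);
  assert(&by_ref == &std::get<TensorDistAttr>(attr));
  assert(&by_value != &std::get<TensorDistAttr>(attr));
  return 0;
}
```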

paddle/phi/infermeta/spmd_rules/matmul.cc

Lines changed: 1 addition & 1 deletion
@@ -291,7 +291,7 @@ SpmdInfo MatmulGradInferSpmd(const DistMetaTensor& x_,
                              bool trans_y) {
   DistMetaTensor x = x_, y = y_;
   auto get_attr = [](const ArgDistAttr& attr) -> const TensorDistAttr& {
-    return paddle::get<TensorDistAttr>(attr);
+    return PADDLE_GET_CONST(TensorDistAttr, attr);
   };

   auto confirm_dist_attr_same_fn = [&](const ArgDistAttr& x_dist_attr,
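
The get_attr lambda above keeps its const TensorDistAttr& return type; only the body changes. A simplified sketch (std::variant and toy types, not Paddle's) of why returning a reference out of the lambda is fine here: it aliases the caller-owned variant rather than a temporary.

```cpp
// Simplified sketch with std::variant and toy types, not Paddle's: the lambda
// returns a const reference into the caller-owned variant, so there is no
// dangling reference and no copy.
#include <iostream>
#include <string>
#include <variant>
#include <vector>

struct TensorDistAttr {
  std::string to_string() const { return "TensorDistAttr"; }
};
using ArgDistAttr = std::variant<TensorDistAttr, std::vector<TensorDistAttr>>;

int main() {
  auto get_attr = [](const ArgDistAttr& attr) -> const TensorDistAttr& {
    return std::get<TensorDistAttr>(attr);
  };
  ArgDistAttr x_attr = TensorDistAttr{};
  const TensorDistAttr& ref = get_attr(x_attr);  // aliases x_attr, no copy
  std::cout << ref.to_string() << "\n";
  return 0;
}
```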

paddle/phi/infermeta/spmd_rules/replicated.cc

Lines changed: 6 additions & 4 deletions
@@ -164,15 +164,17 @@ SpmdInfo ReplicatedInferDynamic(

   for (int64_t i = 0; i < ninputs; i++) {
     if (paddle::holds_alternative<const DistMetaTensor*>(inputs[i])) {
-      auto dist_meta_tensor_ptr = paddle::get<0>(inputs[i]);
-      auto& dist_meta_tensor = *dist_meta_tensor_ptr;
+      const auto* dist_meta_tensor_ptr =
+          PADDLE_GET_CONST(const DistMetaTensor*, inputs[i]);
+      const auto& dist_meta_tensor = *dist_meta_tensor_ptr;
       auto dist_attr_dst = build_tensor_dist_attr(dist_meta_tensor);
       VLOG(4) << "input " << i << ": dist attr: " << dist_attr_dst.to_string();
       spmd_info.first.emplace_back(dist_attr_dst);
     } else {
       std::vector<phi::distributed::TensorDistAttr> list_dist_attr;
-      auto dist_meta_tensors_ptr = paddle::get<1>(inputs[i]);
-      auto& dist_meta_tensors = *dist_meta_tensors_ptr;
+      const auto* dist_meta_tensors_ptr =
+          PADDLE_GET_CONST(const std::vector<DistMetaTensor>*, inputs[i]);
+      const auto& dist_meta_tensors = *dist_meta_tensors_ptr;
       for (const auto& dist_meta_tensor : dist_meta_tensors) {
         auto dist_attr_dst = build_tensor_dist_attr(dist_meta_tensor);
         VLOG(4) << "input " << i
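
ReplicatedInferDynamic dispatches on which alternative the input variant holds: a single tensor pointer or a pointer to a list of tensors. A minimal sketch of that dispatch pattern, assuming the variant shape implied by the branch above (not the real Paddle input type):

```cpp
// Minimal sketch of the dispatch above, assuming the input variant holds
// either a single tensor pointer or a pointer to a vector of tensors; the
// real input type in Paddle may differ.
#include <iostream>
#include <variant>
#include <vector>

struct DistMetaTensor {};
using DynamicInput =
    std::variant<const DistMetaTensor*, const std::vector<DistMetaTensor>*>;

void handle(const DynamicInput& input) {
  if (std::holds_alternative<const DistMetaTensor*>(input)) {
    // Single-tensor branch: extract the pointer by type, then use it.
    const DistMetaTensor* tensor = std::get<const DistMetaTensor*>(input);
    std::cout << "single tensor at " << tensor << "\n";
  } else {
    // List branch: the other alternative is a pointer to a vector.
    const std::vector<DistMetaTensor>* tensors =
        std::get<const std::vector<DistMetaTensor>*>(input);
    std::cout << "list of " << tensors->size() << " tensors\n";
  }
}

int main() {
  DistMetaTensor t;
  std::vector<DistMetaTensor> ts(3);
  const DistMetaTensor* single = &t;
  const std::vector<DistMetaTensor>* list = &ts;
  handle(single);
  handle(list);
  return 0;
}
```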

test/cpp/auto_parallel/spmd_rule_test_util.cc

Lines changed: 11 additions & 7 deletions
@@ -22,22 +22,25 @@ const std::vector<int64_t>& get_dims_mapping(
     const phi::distributed::ArgDistAttr& dist_attr) {
   EXPECT_TRUE(
       paddle::holds_alternative<phi::distributed::TensorDistAttr>(dist_attr));
-  const auto& tensor_attr = paddle::get<0>(dist_attr);
+  const auto& tensor_attr =
+      PADDLE_GET_CONST(phi::distributed::TensorDistAttr, dist_attr);
   return tensor_attr.dims_mapping();
 }

 bool is_partial(const phi::distributed::ArgDistAttr& dist_attr) {
   EXPECT_TRUE(
       paddle::holds_alternative<phi::distributed::TensorDistAttr>(dist_attr));
-  const auto& tensor_attr = paddle::get<0>(dist_attr);
+  const auto& tensor_attr =
+      PADDLE_GET_CONST(phi::distributed::TensorDistAttr, dist_attr);
   return tensor_attr.is_partial();
 }

 const std::set<int64_t> get_partial_dims(
     const phi::distributed::ArgDistAttr& dist_attr) {
   EXPECT_TRUE(
       paddle::holds_alternative<phi::distributed::TensorDistAttr>(dist_attr));
-  const auto& tensor_attr = paddle::get<0>(dist_attr);
+  const auto& tensor_attr =
+      PADDLE_GET_CONST(phi::distributed::TensorDistAttr, dist_attr);
   return tensor_attr.partial_dims();
 }

@@ -74,7 +77,8 @@ void check_empty_dist_attr(const phi::distributed::ArgDistAttr& dist_attr,
   EXPECT_TRUE(
       paddle::holds_alternative<phi::distributed::TensorDistAttr>(dist_attr))
       << line;
-  EXPECT_EQ(paddle::get<0>(dist_attr), phi::distributed::TensorDistAttr());
+  EXPECT_EQ(PADDLE_GET_CONST(phi::distributed::TensorDistAttr, dist_attr),
+            phi::distributed::TensorDistAttr());
 }

 void check_partial_dims(const phi::distributed::ArgDistAttr& dist_attr,
@@ -89,23 +93,23 @@ void check_partial_dims(const phi::distributed::ArgDistAttr& dist_attr,
 void clean_partial_status(phi::distributed::ArgDistAttr* dist_attr) {
   EXPECT_TRUE(
       paddle::holds_alternative<phi::distributed::TensorDistAttr>(*dist_attr));
-  auto& tensor_attr = paddle::get<0>(*dist_attr);
+  auto& tensor_attr = PADDLE_GET(phi::distributed::TensorDistAttr, *dist_attr);
   tensor_attr.clean_partial_status();
 }

 void clean_partial_dims(phi::distributed::ArgDistAttr* dist_attr,
                         std::vector<int64_t> dims) {
   EXPECT_TRUE(
       paddle::holds_alternative<phi::distributed::TensorDistAttr>(*dist_attr));
-  auto& tensor_attr = paddle::get<0>(*dist_attr);
+  auto& tensor_attr = PADDLE_GET(phi::distributed::TensorDistAttr, *dist_attr);
   tensor_attr.clean_partial_dims(dims);
 }

 void set_partial_status(phi::distributed::ArgDistAttr* dist_attr,
                         std::vector<int64_t> dims) {
   EXPECT_TRUE(
       paddle::holds_alternative<phi::distributed::TensorDistAttr>(*dist_attr));
-  auto& tensor_attr = paddle::get<0>(*dist_attr);
+  auto& tensor_attr = PADDLE_GET(phi::distributed::TensorDistAttr, *dist_attr);
   tensor_attr.set_partial_status(dims);
 }
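
Note that the test helpers that mutate the held attribute (clean_partial_status, clean_partial_dims, set_partial_status) switch to PADDLE_GET rather than PADDLE_GET_CONST, presumably because they need a writable reference. A sketch of the same idea with std::variant and a toy TensorDistAttr (assumed semantics, not Paddle's definitions):

```cpp
// Sketch with std::variant and a toy TensorDistAttr (assumed semantics, not
// Paddle's definitions): mutable access returns a writable reference that
// aliases the object held by the variant.
#include <cassert>
#include <cstdint>
#include <set>
#include <variant>
#include <vector>

struct TensorDistAttr {
  std::set<int64_t> partial;
  void set_partial_status(const std::vector<int64_t>& dims) {
    partial.insert(dims.begin(), dims.end());
  }
  void clean_partial_status() { partial.clear(); }
};
using ArgDistAttr = std::variant<TensorDistAttr, std::vector<TensorDistAttr>>;

int main() {
  ArgDistAttr attr = TensorDistAttr{};
  // Non-const access (the role PADDLE_GET plays in the helpers above): the
  // returned reference is writable, so mutations are visible through 'attr'.
  auto& tensor_attr = std::get<TensorDistAttr>(attr);
  tensor_attr.set_partial_status({0, 1});
  assert(std::get<TensorDistAttr>(attr).partial.size() == 2);
  tensor_attr.clean_partial_status();
  assert(std::get<TensorDistAttr>(attr).partial.empty());
  return 0;
}
```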
