Commit a28e6f6

reshard r to p (#56833)

1 parent 413ca98 commit a28e6f6

14 files changed: +319 -22 lines changed

paddle/fluid/pybind/auto_parallel_py.cc

Lines changed: 9 additions & 0 deletions

@@ -32,6 +32,7 @@
 #include "paddle/fluid/distributed/auto_parallel/spmd_rules/common.h"
 #include "paddle/fluid/distributed/auto_parallel/spmd_rules/dist_tensor_spec.h"
 #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
+#include "paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.h"
 #include "paddle/phi/core/distributed/auto_parallel/r_to_s_reshard_function.h"
 #include "paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.h"

@@ -157,6 +158,10 @@ void BindAutoParallel(py::module *m) {
       *m, "SToRReshardFunction", ReshardFunction)
       .def(py::init<>());

+  py::class_<phi::distributed::RToPReshardFunction>(
+      *m, "RToPReshardFunction", ReshardFunction)
+      .def(py::init<>());
+
   py::class_<ProcessMesh>(*m, "ProcessMesh")
       .def(py::init<>())
       .def(py::init<const std::vector<int64_t> &,

@@ -338,6 +343,10 @@
       .def("_is_partial", &TensorDistAttr::is_partial)
       .def("_partial_dims", &TensorDistAttr::partial_dims)
       .def("_clean_partial_dims", &TensorDistAttr::clean_partial_dims)
+      .def("_set_partial_dims",
+           [](TensorDistAttr &self, const std::vector<int64_t> &dims) {
+             self.set_partial_status(dims);
+           })
       .def("_clean_partial_status", &TensorDistAttr::clean_partial_status);

   py::class_<SPMDRuleBase>(*m, "SPMDRuleBase")
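These bindings expose the new R-to-P reshard function and a partial-status setter to Python. A minimal sketch of how they fit together, mirroring the new test added at the end of this commit (the snake_case is_suitable method comes from the existing ReshardFunction base binding, which the test relies on):

import paddle
import paddle.distributed as dist
from paddle.framework import core

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

# Replicated input: no tensor dim is mapped to a mesh dim.
in_attr = dist.DistAttr(mesh=mesh, sharding_specs=[None, None])
# Partial output: mark mesh dim 0 as partial via the new binding.
out_attr = dist.DistAttr(mesh=mesh, sharding_specs=[None, None])
out_attr._set_partial_dims([0])

x = dist.shard_tensor(paddle.ones([4, 8]), dist_attr=in_attr)
func = core.RToPReshardFunction()
assert func.is_suitable(x, out_attr)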

paddle/fluid/pybind/eager_method.cc

Lines changed: 58 additions & 13 deletions

@@ -61,6 +61,8 @@ typedef SSIZE_T ssize_t;
 #include "paddle/phi/api/lib/data_transform.h"
 #include "paddle/phi/core/ddim.h"
 #include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
+#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h"
+#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
 #include "paddle/phi/core/flags.h"
 #include "paddle/phi/core/tensor_utils.h"
 #include "paddle/phi/kernels/funcs/math_function.h"

@@ -99,6 +101,30 @@ Py_ssize_t GetSliceIndexFromPyObject(PyObject* obj) {
   }
 }

+namespace {
+#ifdef PADDLE_WITH_DISTRIBUTE
+phi::DenseTensor ReshardXToReplicated(
+    phi::distributed::DistTensor* dist_tensor) {
+  if (!phi::distributed::IsDimsMappingReplicated(
+          dist_tensor->dist_attr().dims_mapping())) {
+    phi::distributed::TensorDistAttr dist_attr(dist_tensor->dist_attr());
+    std::vector<int64_t> dims_mapping(dist_tensor->dims().size(), -1);
+    dist_attr.set_dims_mapping(dims_mapping);
+
+    // reshard to replicated dist tensor
+    auto* func =
+        phi::distributed::ChooseProperReshardFunction(*dist_tensor, dist_attr);
+    auto* dev_ctx =
+        phi::DeviceContextPool::Instance().Get(dist_tensor->place());
+    auto out_tensor = func->Eval(dev_ctx, *dist_tensor, dist_attr);
+    return out_tensor->value();
+  } else {
+    return dist_tensor->value();
+  }
+}
+#endif
+}  // namespace
+
 PyDoc_STRVAR(tensor_method_numpy__doc__,  // NOLINT
              R"DOC(numpy($self, /)
 --

@@ -145,15 +171,6 @@ static PyObject* tensor_method_numpy(TensorObject* self,
     return array;
   }
   auto tensor_dims = self->tensor.shape();
-#ifdef PADDLE_WITH_DISTRIBUTE
-  // Now the DistTensor's numpy() return the local tensor value
-  if (self->tensor.is_dist_tensor()) {
-    tensor_dims = phi::vectorize(
-        static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get())
-            ->value()
-            .dims());
-  }
-#endif
   auto numpy_dtype = TensorDtype2NumpyDtype(self->tensor.type());
   auto sizeof_dtype = phi::SizeOf(self->tensor.type());
   Py_intptr_t py_dims[paddle::framework::DDim::kMaxRank];  // NOLINT

@@ -258,12 +275,11 @@ static PyObject* tensor_method_numpy(TensorObject* self,
         dense_tensor->Holder()->size());
   } else if (self->tensor.is_dist_tensor()) {
 #ifdef PADDLE_WITH_DISTRIBUTE
-    // TODO(chenweihang): deal with DistTensor as local DenseTensor now,
-    // if the local DenseTensor is shard or partial, do gather or reduce?
     VLOG(6) << "Getting DistTensor's numpy value";
     auto* dist_tensor =
         static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get());
-    auto& dense_tensor = dist_tensor->value();
+    auto dense_tensor = ReshardXToReplicated(dist_tensor);
+
     cpu_tensor.set_meta(dense_tensor.meta());
     // deep copy
     auto tmp_allocation_ptr =

@@ -330,7 +346,8 @@ static PyObject* tensor_method_numpy(TensorObject* self,
     VLOG(6) << "Getting DistTensor's numpy value";
     auto* dist_tensor =
         static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get());
-    auto& dense_tensor = dist_tensor->value();
+    auto dense_tensor = ReshardXToReplicated(dist_tensor);
+
     cpu_tensor.set_meta(dense_tensor.meta());
     auto tmp_allocation_ptr =
         memory::Alloc(cpu_place, dense_tensor.Holder()->size());

@@ -2680,6 +2697,30 @@ static PyObject* tensor__grad_value(TensorObject* self,
   EAGER_CATCH_AND_THROW_RETURN_NULL
 }

+static PyObject* tensor__local_value(TensorObject* self,
+                                     PyObject* args,
+                                     PyObject* kwargs) {
+  EAGER_TRY
+  if (self->tensor.is_dist_tensor()) {
+#ifdef PADDLE_WITH_DISTRIBUTE
+    phi::distributed::DistTensor* dist_tensor =
+        static_cast<phi::distributed::DistTensor*>(self->tensor.impl().get());
+    paddle::Tensor result(
+        std::make_shared<phi::DenseTensor>(dist_tensor->value()));
+    return ToPyObject(result);
+#else
+    PADDLE_THROW(platform::errors::Unavailable(
+        "The `_local_value` method of (Dist)Tensor is not supported "
+        "in the current PaddlePaddle, please recompile and install "
+        "PaddlePaddle "
+        "with the option of `WITH_DISTRIBUTE=ON`."));
+#endif
+  } else {
+    RETURN_PY_NONE
+  }
+  EAGER_CATCH_AND_THROW_RETURN_NULL
+}
+
 static PyObject* tensor__unset_fake_empty(TensorObject* self,
                                           PyObject* args,
                                           PyObject* kwargs) {

@@ -3131,6 +3172,10 @@ PyMethodDef variable_methods[] = {  // NOLINT
      (PyCFunction)(void (*)())tensor__grad_value,
      METH_VARARGS | METH_KEYWORDS,
      nullptr},
+    {"_local_value",
+     (PyCFunction)(void (*)())tensor__local_value,
+     METH_VARARGS | METH_KEYWORDS,
+     nullptr},
     {"_unset_fake_empty",
      (PyCFunction)(void (*)())tensor__unset_fake_empty,
      METH_VARARGS | METH_KEYWORDS,
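Two behaviors are worth distinguishing here. numpy() now reshards a sharded DistTensor back to replicated before copying (ReshardXToReplicated fires only when the dims mapping is not already all -1, so a purely partial tensor still comes back as its local value), while the new _local_value() always returns the raw per-rank DenseTensor with no communication. A hedged sketch, reusing the mesh from the snippet above, with illustrative shapes:

# y is sharded along tensor dim 0 over the 2-rank mesh
sharded_attr = dist.DistAttr(mesh=mesh, sharding_specs=["x", None])
y = dist.shard_tensor(paddle.ones([4, 8]), dist_attr=sharded_attr)

y._local_value().shape  # local shard only, e.g. [2, 8]; no communication
y.numpy().shape         # (4, 8): resharded to replicated before the copy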

paddle/phi/core/distributed/auto_parallel/CMakeLists.txt

Lines changed: 2 additions & 1 deletion

@@ -13,4 +13,5 @@ collect_srcs(
   inferspmd_utils.cc
   reshard_function.cc
   r_to_s_reshard_function.cc
-  s_to_r_reshard_function.cc)
+  s_to_r_reshard_function.cc
+  r_to_p_reshard_function.cc)

paddle/phi/core/distributed/auto_parallel/dist_attr.cc

Lines changed: 1 addition & 1 deletion

@@ -227,7 +227,7 @@ bool TensorDistAttr::verify_partial_status() const {
     if (itr.first < 0 || itr.first >= process_mesh_.ndim()) {
       return false;
     }
-    if (itr.second < ReduceType::kRedSum || itr.second <= ReduceType::kRedAll) {
+    if (itr.second < ReduceType::kRedSum || itr.second > ReduceType::kRedAll) {
       return false;
     }
   }

This fixes an inverted comparison: the old condition was true for every reduce type at or below kRedAll, so verification rejected all valid values; the corrected check only rejects reduce types outside the [kRedSum, kRedAll] range.
paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.cc

Lines changed: 80 additions & 0 deletions

@@ -0,0 +1,80 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.h"
+
+#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
+#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
+#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
+#include "paddle/phi/kernels/assign_kernel.h"
+#include "paddle/phi/kernels/full_kernel.h"
+
+namespace phi {
+namespace distributed {
+
+bool RToPReshardFunction::IsSuitable(const DistTensor& in,
+                                     const TensorDistAttr& out_dist_attr) {
+  bool flag = true;
+  const auto& in_dist_attr = in.dist_attr();
+
+  const auto& in_dims_mapping = in_dist_attr.dims_mapping();
+
+  flag &= IsDimsMappingReplicated(in_dims_mapping);
+  flag &= out_dist_attr.is_partial();
+
+  const auto& in_process_mesh = in_dist_attr.process_mesh();
+  const auto& out_process_mesh = out_dist_attr.process_mesh();
+
+  flag &= (in_process_mesh.ndim() == 1);
+  flag &= (out_process_mesh.ndim() == 1);
+  flag &= (in_process_mesh == out_process_mesh);
+
+  return flag;
+}
+
+void RToPReshardFunction::Eval(phi::DeviceContext* dev_ctx,
+                               const DistTensor& in,
+                               const TensorDistAttr& out_dist_attr,
+                               DistTensor* out) {
+  const auto& out_process_mesh = out_dist_attr.process_mesh();
+  int64_t local_rank = GetCurRankCoordInMesh(out_process_mesh)[0];
+  IntArray shape(in.dims().Get(), in.dims().size());
+
+  if (local_rank != 0) {
+    // reset the physical tensor to zero
+    RESHARD_FUNCTOR(dev_ctx, Full, in.dtype(), shape, 0, GetMutableTensor(out));
+  } else {
+    // assign the input value to output
+    if (phi::CPUContext::classof(dev_ctx)) {
+      Assign(static_cast<const CPUContext&>(*dev_ctx),
+             in.value(),
+             GetMutableTensor(out));
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+    } else if (phi::GPUContext::classof(dev_ctx)) {
+      Assign(static_cast<const GPUContext&>(*dev_ctx),
+             in.value(),
+             GetMutableTensor(out));
+#endif
+    } else {
+      PADDLE_THROW(phi::errors::Unimplemented(
+          "The assign in reshard only supported on CPU and GPU for now."));
+    }
+  }
+  SetDistProps(out, in.dims(), out_dist_attr);
+}
+
+REGISTER_RESHARD_FUNC(RToPReshardFunction);
+
+}  // namespace distributed
+}  // namespace phi
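Eval implements replicated-to-partial by keeping the full value on the rank at coordinate 0 of the mesh (the Assign branch) and zero-filling every other rank (the Full branch), so that a later sum-reduction over the partial tensor reproduces the original value. A conceptual sketch of the resulting per-rank local values on the 2-rank mesh used by the test, not part of the commit:

import numpy as np

rank0_local = np.ones((2, 2))        # rank 0 keeps the input (Assign path)
rank1_local = np.zeros((2, 2))       # other ranks are zero-filled (Full path)
logical = rank0_local + rank1_local  # a later sum-reduce over the mesh
assert (logical == np.ones((2, 2))).all()  # recovers the replicated input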
paddle/phi/core/distributed/auto_parallel/r_to_p_reshard_function.h

Lines changed: 34 additions & 0 deletions

@@ -0,0 +1,34 @@
+// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/distributed/auto_parallel/reshard_function.h"
+
+namespace phi {
+namespace distributed {
+
+class RToPReshardFunction final : public ReshardFunction {
+ public:
+  bool IsSuitable(const DistTensor& in,
+                  const TensorDistAttr& out_dist_attr) override;
+
+  void Eval(DeviceContext* dev_ctx,
+            const DistTensor& in,
+            const TensorDistAttr& out_dist_attr,
+            DistTensor* out) override;
+};
+
+}  // namespace distributed
+}  // namespace phi

paddle/phi/core/distributed/auto_parallel/s_to_r_reshard_function.cc

Lines changed: 0 additions & 1 deletion

@@ -76,7 +76,6 @@ void SToRReshardFunction::Eval(DeviceContext* dev_ctx,
                  in.value(),
                  in_process_ids.size(),
                  GetMutableTensor(out));
-
   std::map<int64_t, int64_t> split_axis_to_mesh_axis =
       GetSplitAxisWithDimsMapping(in_dims_mapping);
   int64_t split_axis = split_axis_to_mesh_axis.begin()->first;

paddle/phi/kernels/assign_kernel.h

Lines changed: 8 additions & 0 deletions

@@ -38,6 +38,14 @@ DenseTensor Assign(const Context& dev_ctx, const DenseTensor& x) {
   return out;
 }

+template <typename Context>
+void Assign(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) {
+  MetaTensor meta_out(out);
+  MetaTensor meta_x(x);
+  UnchangedInferMeta(meta_x, &meta_out);
+  AssignKernel<Context>(dev_ctx, x, out);
+}
+
 // In order to be compatible with the `AsDispensable` input in the original
 // assign op maker, the input parameter here needs to be dispensable, but
 // this looks weird

test/auto_parallel/CMakeLists.txt

Lines changed: 3 additions & 0 deletions

@@ -85,6 +85,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_reshard_r_to_s MODULES test_reshard_r_to_s)
   set_tests_properties(test_reshard_r_to_s
                        PROPERTIES LABELS "RUN_TYPE=EXECLUSIVE" TIMEOUT 100)
+  py_test_modules(test_reshard_r_to_p MODULES test_reshard_r_to_p)
+  set_tests_properties(test_reshard_r_to_p
+                       PROPERTIES LABELS "RUN_TYPE=EXECLUSIVE" TIMEOUT 100)
   # End of unittests WITH multi cards and timeout

   # NOTE(zyl): unittests WITH multi cards and WITHOUT timeout

test/auto_parallel/reshard_r_to_p.py

Lines changed: 73 additions & 0 deletions

@@ -0,0 +1,73 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import numpy as np
+
+import paddle
+import paddle.distributed as dist
+from paddle.framework import core
+
+
+class TestReshardRToP:
+    def __init__(self):
+        self._shape = eval(os.getenv("shape"))
+        self._dtype = os.getenv("dtype")
+        self._seeds = eval(os.getenv("seeds"))
+        self._backend = os.getenv("backend")
+        self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
+
+    def run_test_case(self):
+        if self._backend == "cpu":
+            paddle.set_device("cpu")
+            place = paddle.CPUPlace()
+        elif self._backend == "gpu":
+            place = paddle.CUDAPlace(dist.get_rank())
+
+        dev_ctx = core.DeviceContext.create(place)
+        a = paddle.ones(self._shape)
+
+        in_shard_specs = [None for i in range(len(self._shape))]
+        out_shard_specs = [None for i in range(len(self._shape))]
+
+        dist_attr = dist.DistAttr(
+            mesh=self._mesh, sharding_specs=in_shard_specs
+        )
+        out_dist_attr = dist.DistAttr(
+            mesh=self._mesh, sharding_specs=out_shard_specs
+        )
+        out_dist_attr._set_partial_dims([0])
+
+        input_tensor = dist.shard_tensor(a, dist_attr=dist_attr)
+
+        reshard_func = core.RToPReshardFunction()
+        assert reshard_func.is_suitable(input_tensor, out_dist_attr)
+
+        out = reshard_func.eval(dev_ctx, input_tensor, out_dist_attr)
+
+        if dist.get_rank() == 0:
+            np.testing.assert_equal(
+                out._local_value().numpy(), input_tensor.numpy()
+            )
+        else:
+            zeros = paddle.zeros(self._shape)
+            np.testing.assert_equal(out._local_value().numpy(), zeros.numpy())
+
+        assert np.equal(out.shape, input_tensor.shape).all()
+        assert np.equal(out._local_shape, input_tensor._local_shape).all()
+
+
+if __name__ == '__main__':
+    TestReshardRToP().run_test_case()
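The case reads its parameters (shape, dtype, seeds, backend) from environment variables and runs on both ranks of the two-process mesh; the test_reshard_r_to_p driver registered in the CMake change above is what sets them, and that driver file is not shown in this excerpt. A hypothetical manual two-GPU invocation with illustrative values might look like:

shape="(10, 20)" dtype="float32" seeds="(2023,)" backend="gpu" \
    python -m paddle.distributed.launch --devices "0,1" reshard_r_to_p.py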
