
Commit 71168da

[Cherry Pick] Bug fix and speedup dygraph multi-cards on v1.5 (#19298)
* add warning info for CPU_NUM test=develop
* update dygraph parallel.py test=develop
* prune the feed op in compiler test=release/1.5
* remove compile from PE test=develop
* test CUDAPinnedPlace in reader test=release/1.5
1 parent 3b5f354 commit 71168da
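
The dygraph-side changes center on DataParallel's gradient allreduce path. For orientation, a rough sketch of where they take effect (a hypothetical training script assuming the 1.5 dygraph API, launched once per GPU, e.g. via python -m paddle.distributed.launch):

    import paddle.fluid as fluid
    from paddle.fluid.dygraph.parallel import DataParallel, Env, prepare_context

    # assumes a multi-card launch so that Env().nranks > 1
    place = fluid.CUDAPlace(Env().dev_id)
    with fluid.dygraph.guard(place):
        strategy = prepare_context()
        model = fluid.dygraph.FC("fc", size=10)   # any dygraph Layer
        model = DataParallel(model, strategy)

        # per training step:
        #   loss = model(...)                  # forward
        #   loss = model.scale_loss(loss)      # scale by nranks before backward
        #   loss.backward()
        #   model.apply_collective_grads()     # coalesced allreduce changed by this commit
        #   optimizer.minimize(loss); model.clear_gradients()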

File tree

9 files changed (+226, −37 lines)


paddle/fluid/framework/tensor_util.cc

Lines changed: 4 additions & 0 deletions
@@ -99,6 +99,8 @@ void TensorCopy(const Tensor& src, const platform::Place& dst_place,
         PADDLE_THROW("ctx is not belong to dst_gpu_place or src_gpu_place.");
       }
     }
+  } else {
+    PADDLE_THROW("Copy from %s to %s is not supported.", src_place, dst_place);
   }
 #endif
 }
@@ -166,6 +168,8 @@ void TensorCopySync(const Tensor& src, const platform::Place& dst_place,
     auto dst_gpu_place = boost::get<platform::CUDAPlace>(dst_place);
     memory::Copy(dst_gpu_place, dst_ptr, src_pinned_place, src_ptr, size,
                  nullptr);
+  } else {
+    PADDLE_THROW("Copy from %s to %s is not supported.", src_place, dst_place);
   }
 #endif
 }

paddle/fluid/operators/assign_op.cc

Lines changed: 39 additions & 23 deletions
@@ -66,27 +66,47 @@ class AssignFunctor {
   const platform::DeviceContext &dev_ctx_;
 };
 
-class AssignOp : public framework::OperatorBase {
+class AssignOp : public framework::OperatorWithKernel {
  public:
   AssignOp(const std::string &type, const framework::VariableNameMap &inputs,
            const framework::VariableNameMap &outputs,
            const framework::AttributeMap &attrs)
-      : OperatorBase(type, inputs, outputs, attrs) {}
+      : OperatorWithKernel(type, inputs, outputs, attrs) {}
 
- private:
-  void RunImpl(const framework::Scope &scope,
-               const platform::Place &place) const override {
-    auto *x = scope.FindVar(Input("X"));
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    if (ctx->HasInput("X")) {
+      auto type = ctx->GetInputsVarType("X")[0];
+      if (type == framework::proto::VarType::SELECTED_ROWS ||
+          type == framework::proto::VarType::LOD_TENSOR) {
+        ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
+        if (type == framework::proto::VarType::LOD_TENSOR) {
+          ctx->ShareLoD("X", /*->*/ "Out");
+        }
+      }
+    }
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(ctx.Input<framework::LoDTensor>("X")->type(),
+                                   ctx.device_context());
+  }
+};
+
+class AssignKernel {
+ public:
+  void operator()(const framework::ExecutionContext &ctx) const {
+    auto *x = ctx.InputVar("X");
     if (x == nullptr) {
       return;
     }
-    auto *out = scope.FindVar(Output("Out"));
+    auto *out = ctx.OutputVar("Out");
     PADDLE_ENFORCE(
         out != nullptr,
         "The Output(Out) should not be null if the Input(X) is set.");
-
     platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
-    auto &dev_ctx = *pool.Get(place);
+    auto &dev_ctx = *pool.Get(ctx.GetPlace());
 
     framework::VisitVarType(*x, AssignFunctor(out, dev_ctx));
   }
@@ -110,19 +130,6 @@ raise error if the type is not listed above.
   }
 };
 
-class AssignInferShape : public framework::InferShapeBase {
- public:
-  void operator()(framework::InferShapeContext *context) const override {
-    if (context->HasInput("X")) {
-      auto type = context->GetInputsVarType("X")[0];
-      if (type == framework::proto::VarType::SELECTED_ROWS ||
-          type == framework::proto::VarType::LOD_TENSOR) {
-        context->SetOutputDim("Out", context->GetInputDim("X"));
-      }
-    }
-  }
-};
-
 class AssignGradMaker : public framework::SingleGradOpDescMaker {
  public:
   using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
@@ -142,4 +149,13 @@ class AssignGradMaker : public framework::SingleGradOpDescMaker {
 
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(assign, ops::AssignOp, ops::AssignGradMaker,
-                  ops::AssignInferShape, ops::AssignOpProtoMaker);
+                  ops::AssignOpProtoMaker);
+REGISTER_OP_CPU_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double,
+                               ops::AssignKernel, int, ops::AssignKernel,
+                               int64_t, ops::AssignKernel);
+
+#ifdef PADDLE_WITH_CUDA
+REGISTER_OP_CUDA_KERNEL_FUNCTOR(assign, float, ops::AssignKernel, double,
+                                ops::AssignKernel, int, ops::AssignKernel,
+                                int64_t, ops::AssignKernel);
+#endif
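
With assign now a kernel-based operator registered for float, double, int and int64, it is still driven from Python through fluid.layers.assign; a minimal sketch (assuming PaddlePaddle 1.5 with this commit, names are illustrative):

    import numpy as np
    import paddle.fluid as fluid

    x = fluid.layers.data(name='x', shape=[3], dtype='float32')
    y = fluid.layers.assign(x)  # dispatched to AssignKernel for the input's dtype

    exe = fluid.Executor(fluid.CPUPlace())
    out, = exe.run(feed={'x': np.ones((1, 3), dtype='float32')}, fetch_list=[y])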

paddle/fluid/operators/reader/buffered_reader.cc

Lines changed: 11 additions & 2 deletions
@@ -128,9 +128,18 @@ void BufferedReader::ReadAsync(size_t i) {
                      boost::get<platform::CUDAPlace>(cpu_place), cpu_ptr,
                      size, stream_);
       } else {
+        platform::CUDAPinnedPlace cuda_pinned_place;
+        framework::LoDTensor cuda_pinned_tensor;
+        cuda_pinned_tensor.Resize(cpu[i].dims());
+        auto cuda_pinned_ptr =
+            cuda_pinned_tensor.mutable_data(cuda_pinned_place, cpu[i].type());
+        memory::Copy(cuda_pinned_place, cuda_pinned_ptr,
+                     boost::get<platform::CPUPlace>(cpu_place), cpu_ptr,
+                     size);
         memory::Copy(boost::get<platform::CUDAPlace>(place_), gpu_ptr,
-                     boost::get<platform::CPUPlace>(cpu_place), cpu_ptr, size,
-                     stream_);
+                     cuda_pinned_place, cuda_pinned_ptr, size, stream_);
+        PADDLE_ENFORCE(cudaStreamSynchronize(stream_),
+                       "cuda stream sync error.");
       }
       gpu[i].set_lod(cpu[i].lod());
     }

paddle/fluid/pybind/pybind.cc

Lines changed: 3 additions & 1 deletion
@@ -64,13 +64,13 @@ limitations under the License. */
 #ifndef _WIN32
 #include "paddle/fluid/pybind/nccl_wrapper_py.h"
 #endif
+#include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/pybind/protobuf.h"
 #include "paddle/fluid/pybind/pybind.h"  // NOLINT
 #include "paddle/fluid/pybind/reader_py.h"
 #include "paddle/fluid/pybind/recordio.h"
 #include "paddle/fluid/pybind/tensor_py.h"
 #include "paddle/fluid/string/to_string.h"
-
 #ifdef PADDLE_WITH_CUDA
 #ifndef _WIN32
 #include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
@@ -1118,6 +1118,8 @@ All parameter, weight, gradient are variables in Paddle.
         return std::shared_ptr<framework::ir::Pass>(std::move(pass));
       });
 
+  m.def("size_of_dtype", framework::SizeOfType);
+
   py::class_<ir::Pass, std::shared_ptr<ir::Pass>> pass(m, "Pass");
   pass.def(py::init())
       .def("has", &ir::Pass::Has)
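
The new size_of_dtype binding is what the dygraph gradient-grouping code (see parallel.py below) uses to estimate tensor sizes in bytes. A quick sanity check from Python might look like this (a sketch, assuming the binding is exposed on paddle.fluid.core exactly as registered above):

    from paddle.fluid import core

    # FP32 elements are 4 bytes each
    print(core.size_of_dtype(core.VarDesc.VarType.FP32))               # expected: 4
    # so a [256, 1024] float32 tensor is 1 MB
    print(256 * 1024 * core.size_of_dtype(core.VarDesc.VarType.FP32))  # expected: 1048576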

python/paddle/fluid/compiler.py

Lines changed: 10 additions & 2 deletions
@@ -46,6 +46,15 @@ def _is_pserver_mode(main_program):
     return False
 
 
+def _prune_feed_ops(program):
+    # prune the feed ops in the program.
+    pop_idx = []
+    for i, op in enumerate(program.global_block().ops):
+        if op.type == "feed": pop_idx.append(i)
+    for index in pop_idx[::-1]:
+        program.global_block()._remove_op(index)
+
+
 class CompiledProgram(object):
     """
     Compiles to Graph for execution.
@@ -101,6 +110,7 @@ def __init__(self, program_or_graph):
             # don't not create a new program here.
             self._program = None
         elif isinstance(program_or_graph, framework.Program):
+            _prune_feed_ops(program_or_graph)
            self._graph = core.Graph(program_or_graph.desc)
             self._program = program_or_graph
         else:
@@ -274,8 +284,6 @@ def _compile_data_parallel(self, use_cuda=False, scope=None):
                     "share_vars_from is not compiled and run, so there is no "
                     "var to share.")
             self._local_scopes = self._share_vars_from._executor.local_scopes()
-            # drop the local_exe_scopes of the previous parallel_executor
-            self._share_vars_from._executor.drop_local_exe_scopes()
         else:
             assert scope is not None, ""
             self._local_scopes = []
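
_prune_feed_ops strips any feed ops already present in a Program before it is turned into a Graph, so stale feed ops picked up from a previous executor run are not compiled into the parallel graph. A minimal, hypothetical illustration of the helper on its own (it is module-private; calling it directly is only for demonstration):

    import paddle.fluid as fluid
    from paddle.fluid import compiler

    prog = fluid.Program()
    with fluid.program_guard(prog):
        x = fluid.layers.data(name='x', shape=[1], dtype='float32')
        fluid.layers.scale(x, scale=2.0)

    compiler._prune_feed_ops(prog)
    print([op.type for op in prog.global_block().ops])  # no 'feed' ops remain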

python/paddle/fluid/dygraph/parallel.py

Lines changed: 65 additions & 3 deletions
@@ -14,7 +14,7 @@
 import os
 import six
 import numpy as np
-
+from collections import OrderedDict
 from .. import core
 from . import layers
 from . import parallel_helper
@@ -36,7 +36,7 @@ def prepare_context(strategy=None):
     strategy.current_endpoint = Env().current_endpoint
     if strategy.nranks < 2:
         return
-    assert framework.in_dygraph_mode() is True,\
+    assert framework.in_dygraph_mode() is True, \
         "dygraph.parallel.prepare_context should be used with dygrahp mode."
     place = framework._current_expected_place()
     assert place is not None, \
@@ -168,13 +168,46 @@ def scale_loss(self, loss):
         loss = loss / loss_scale
         return loss
 
+    def _coalesce_tensors(self, var_groups):
+        from ..layers import nn
+        coalesced_grads_and_grad_vars = []
+        for group_id, grad_vars in var_groups.items():
+            flattened_vars = []
+            g_var_shapes = []
+            for g_var in grad_vars:
+                g_var_shapes.append(g_var.shape)
+                flattened_vars.append(
+                    nn.reshape(
+                        x=g_var, shape=[np.prod(g_var.shape)], inplace=True))
+            coalesced_grad = nn.concat(flattened_vars)
+            coalesced_grads_and_grad_vars.append(
+                [coalesced_grad, grad_vars, g_var_shapes])
+        return coalesced_grads_and_grad_vars
+
+    def _split_tensors(self, coalesced_grads_and_grad_vars):
+        from ..layers import nn
+        for coalesced_grad, origin_grad_vars, grad_shapes in coalesced_grads_and_grad_vars:
+            grad_var_len = [np.prod(g_shape) for g_shape in grad_shapes]
+            splited_vars = nn.split(
+                coalesced_grad, num_or_sections=grad_var_len, dim=0)
+            reshaped_grad_vars = []
+            for g_var, g_shape in zip(splited_vars, grad_shapes):
+                reshaped_grad_vars.append(
+                    nn.reshape(
+                        x=g_var, shape=g_shape, inplace=True))
+            for origin_g_var, reshaped_g_var in zip(origin_grad_vars,
+                                                    reshaped_grad_vars):
+                nn.assign(input=reshaped_g_var, output=origin_g_var)
+
     def apply_collective_grads(self):
         """
         AllReduce the Parameters' gradient.
         """
         if not self._is_data_parallel_mode():
             return
 
+        grad_var_set = set()
+        grad_vars = []
         for param in self._layers.parameters():
             # NOTE(zcd): The grad_ivar maybe no generated.
             if param.trainable and param._ivar._grad_ivar():
@@ -183,7 +216,36 @@ def apply_collective_grads(self):
                     name=param._ivar._grad_name(),
                     stop_gradient=True,
                     ivar=param._ivar._grad_ivar())
-                collective._allreduce(g_var, g_var, sync_mode=True)
+                grad_vars.append(g_var)
+                assert g_var not in grad_var_set
+                grad_var_set.add(g_var)
+
+        # FIXME(zcd): the type of the var should be LoDTensor, i.e
+        # the gradients should be dense, otherwise, the following
+        # logic should be updated.
+        # 128 MB as a group
+        mega_bytes = 128 * 1024 * 1024
+        group_idx = 0
+        memory_counter = 0
+        grad_var_groups = OrderedDict()
+        dtype = grad_vars[0].dtype
+        for g_var in grad_vars:
+            # Note: the dtype of the same group should be the same.
+            bytes = np.prod(g_var.shape) * core.size_of_dtype(g_var.dtype)
+            if memory_counter < mega_bytes and dtype == g_var.dtype:
+                memory_counter += bytes
+            else:
+                memory_counter = bytes
+                group_idx += 1
+            grad_var_groups.setdefault(group_idx, []).append(g_var)
+
+        coalesced_grads_and_vars = self._coalesce_tensors(grad_var_groups)
+
+        for coalesced_grad, g_vars, g_shapes in coalesced_grads_and_vars:
+            collective._allreduce(
+                coalesced_grad, coalesced_grad, sync_mode=False)
+
+        self._split_tensors(coalesced_grads_and_vars)
 
     def _is_data_parallel_mode(self):
         return self._strategy.nranks > 1
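
The grouping loop above packs gradients into roughly 128 MB buckets per dtype, so one allreduce is issued per coalesced buffer instead of one per parameter. A standalone sketch of the same bucketing rule with made-up byte sizes:

    mega_bytes = 128 * 1024 * 1024
    # hypothetical gradient sizes: 90 MB, 60 MB, 10 MB
    sizes = [90 * 1024 * 1024, 60 * 1024 * 1024, 10 * 1024 * 1024]

    groups, group_idx, memory_counter = {}, 0, 0
    for i, nbytes in enumerate(sizes):
        if memory_counter < mega_bytes:
            memory_counter += nbytes
        else:
            memory_counter = nbytes
            group_idx += 1
        groups.setdefault(group_idx, []).append(i)

    print(groups)  # {0: [0, 1], 1: [2]}: the first two gradients share one allreduce buffer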

python/paddle/fluid/framework.py

Lines changed: 8 additions & 5 deletions
@@ -85,11 +85,14 @@ def _current_expected_place():
 
 def _cpu_num():
     if "CPU_NUM" not in os.environ.keys():
-        sys.stderr.write(
-            'The CPU_NUM is not specified, you should set CPU_NUM in '
-            'the environment variable list, i.e export CPU_NUM=1. CPU_NUM '
-            'indicates that how many CPUPlace are used in the current task.\n'
-            '!!! The default number of CPUPlaces is 1.\n\n')
+        if multiprocessing.cpu_count() > 1:
+            sys.stderr.write(
+                '!!! The CPU_NUM is not specified, you should set CPU_NUM in the environment variable list.\n'
+                'CPU_NUM indicates that how many CPUPlace are used in the current task.\n'
+                'And if this parameter are set as N (equal to the number of physical CPU core) the program may be faster.\n\n'
+                'export CPU_NUM={} # for example, set CPU_NUM as number of physical CPU core which is {}.\n\n'
+                '!!! The default number of CPU_NUM=1.\n'.format(
+                    multiprocessing.cpu_count(), multiprocessing.cpu_count()))
         os.environ['CPU_NUM'] = str(1)
     cpu_num = os.environ.get('CPU_NUM')
     return int(cpu_num)
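
To silence the warning and use more than one CPUPlace, CPU_NUM can be set before fluid creates places; a hedged sketch of the Python equivalent of the suggested export:

    import os
    import multiprocessing

    # mirror the `export CPU_NUM=...` suggestion from the warning above
    os.environ['CPU_NUM'] = str(multiprocessing.cpu_count())

    import paddle.fluid as fluid  # _cpu_num() now reads CPU_NUM and skips the warning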

python/paddle/fluid/parallel_executor.py

Lines changed: 1 addition & 1 deletion
@@ -163,6 +163,7 @@ def __init__(self,
             assert isinstance(
                 share_vars_from, ParallelExecutor
             ), "The share_vars_from should be ParallelExecutor."
+
         self._compiled_program.with_data_parallel(
             loss_name=loss_name,
             build_strategy=build_strategy,
@@ -172,7 +173,6 @@ def __init__(self,
 
         self._place = core.CUDAPlace(0) if use_cuda else core.CPUPlace()
         self._exe = executor.Executor(self._place)
-        self._compiled_program._compile(place=self._place, scope=self._scope)
 
     def run(self, fetch_list, feed=None, feed_dict=None, return_numpy=True):
         """
