Commit dca9941

pass size when copy
1 parent: 67d6f3a

4 files changed: 12 additions & 8 deletions

paddle/fluid/operators/concat_op.h

Lines changed: 2 additions & 2 deletions
@@ -38,7 +38,7 @@ class ConcatKernel : public framework::OpKernel<T> {
       auto in_stride = framework::stride_numel(in->dims());
       StridedNumelCopyWithAxis<T>(ctx.device_context(), axis,
                                   out->data<T>() + output_offset, out_stride,
-                                  in->data<T>(), in_stride);
+                                  in->data<T>(), in_stride, in_stride[axis]);
       output_offset += in_stride[axis];
     }
   }
@@ -59,7 +59,7 @@ class ConcatGradKernel : public framework::OpKernel<T> {
       auto out_stride = framework::stride_numel(out->dims());
       StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
                                   out_stride, in->data<T>() + input_offset,
-                                  in_stride);
+                                  in_stride, out_stride[axis]);
       input_offset += out_stride[axis];
     }
   }
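Both concat kernels now tell StridedNumelCopyWithAxis explicitly how many elements to copy per outer row (in_stride[axis] in the forward pass, out_stride[axis] in the gradient pass) instead of letting the helper infer the width itself. A minimal standalone sketch of the forward copy pattern, using a hypothetical strided_copy helper that is not the Paddle API:

// Standalone illustration (hypothetical names, not the Paddle API): copy
// `size` elements per outer row from src (row width src_after) into dst
// (row width dst_after), starting at a column offset inside dst.
#include <cstdint>
#include <cstring>
#include <iostream>
#include <vector>

void strided_copy(float* dst, int64_t dst_after, const float* src,
                  int64_t src_after, int64_t before, int64_t size) {
  for (int64_t i = 0; i < before; ++i) {
    std::memcpy(dst + i * dst_after, src + i * src_after,
                sizeof(float) * size);
  }
}

int main() {
  // Concat a 2x2 and a 2x3 matrix along axis 1 into a 2x5 output.
  std::vector<float> a = {1, 2, 3, 4};         // 2x2
  std::vector<float> b = {5, 6, 7, 8, 9, 10};  // 2x3
  std::vector<float> out(2 * 5, 0);
  int64_t offset = 0;
  strided_copy(out.data() + offset, 5, a.data(), 2, /*before=*/2, /*size=*/2);
  offset += 2;  // advance by the first input's width along the axis
  strided_copy(out.data() + offset, 5, b.data(), 3, /*before=*/2, /*size=*/3);
  for (float v : out) std::cout << v << ' ';   // 1 2 5 6 7 3 4 8 9 10
  std::cout << '\n';
}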

paddle/fluid/operators/split_op.h

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ class SplitOpKernel : public framework::OpKernel<T> {
       auto out_stride = framework::stride_numel(out->dims());
       StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
                                   out_stride, in->data<T>() + input_offset,
-                                  in_stride);
+                                  in_stride, out_stride[axis]);
       input_offset += out_stride[axis];
     }
   }
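Split is the mirror image: here the source is the wide tensor and each output receives out_stride[axis] elements per row, so the kernel passes that as the explicit copy size. Continuing the hypothetical sketch above, splitting the 2x5 result back into its pieces:

// Mirror of the concat sketch: recover the 2x2 and 2x3 pieces from out.
std::vector<float> a2(2 * 2), b2(2 * 3);
strided_copy(a2.data(), 2, out.data() + 0, 5, /*before=*/2, /*size=*/2);
strided_copy(b2.data(), 3, out.data() + 2, 5, /*before=*/2, /*size=*/3);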

paddle/fluid/operators/strided_memcpy.h

Lines changed: 4 additions & 5 deletions
@@ -54,11 +54,11 @@ inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx,
                                      int64_t axis, T* dst,
                                      const framework::DDim& dst_stride_numel,
                                      const T* src,
-                                     const framework::DDim& src_stride_numel) {
+                                     const framework::DDim& src_stride_numel,
+                                     int64_t size) {
   int64_t before = dst_stride_numel[0] / dst_stride_numel[axis];
   int64_t src_after = src_stride_numel[axis];
   int64_t dst_after = dst_stride_numel[axis];
-  int64_t copy_size = std::min(src_after, dst_after);
   auto place = ctx.GetPlace();

   PADDLE_ENFORCE_EQ(src_stride_numel.size(), dst_stride_numel.size(),
@@ -83,15 +83,14 @@ inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx,
     if (platform::is_cpu_place(place)) {
       auto& cpu_place = boost::get<platform::CPUPlace>(place);
       memory::Copy(cpu_place, dst + i * dst_after, cpu_place,
-                   src + i * src_after, sizeof(T) * copy_size);
+                   src + i * src_after, sizeof(T) * size);
     } else {
 #ifdef PADDLE_WITH_CUDA
       auto& gpu_place = boost::get<platform::CUDAPlace>(place);
       auto& cuda_ctx =
           reinterpret_cast<const platform::CUDADeviceContext&>(ctx);
       memory::Copy(gpu_place, dst + i * dst_after, gpu_place,
-                   src + i * src_after, sizeof(T) * copy_size,
-                   cuda_ctx.stream());
+                   src + i * src_after, sizeof(T) * size, cuda_ctx.stream());
 #else
       PADDLE_THROW("Paddle is not compiled with GPU");
 #endif
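This is the substantive change: StridedNumelCopyWithAxis used to derive the per-row copy width itself as std::min(src_after, dst_after); it now takes an explicit size argument from the caller. A hedged sketch of the resulting contract (illustrative only; the assert is an assumption of mine, not part of the diff):

#include <cassert>
#include <cstdint>
#include <cstring>

// Sketch of the updated contract: `before` outer rows, source rows of
// src_after elements, destination rows of dst_after elements, and a
// caller-supplied `size` that must fit inside both row widths.
template <typename T>
void strided_numel_copy_with_axis(T* dst, int64_t dst_after, const T* src,
                                  int64_t src_after, int64_t before,
                                  int64_t size) {
  assert(size <= src_after && size <= dst_after);  // assumed safety check
  for (int64_t i = 0; i < before; ++i) {
    std::memcpy(dst + i * dst_after, src + i * src_after, sizeof(T) * size);
  }
}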

python/paddle/v2/fluid/distribute_transpiler.py

Lines changed: 5 additions & 0 deletions
@@ -121,6 +121,7 @@ def split_dense_variable(var_list,
             block_size += dim1 - remains
         # update split_count after aligning
         split_count = int(math.ceil(var_numel / float(block_size)))
+        print("###split var ", var.name, var.shape, block_size, split_count)
         for block_id in xrange(split_count):
             curr_block_size = min(block_size, var_numel - (
                 (block_id) * block_size))
@@ -255,13 +256,15 @@ def _create_vars_from_blocklist(self, program, block_list):
                 splited_shape = [rows]
                 if len(orig_shape) >= 2:
                     splited_shape.extend(orig_shape[1:])
+                print("###splited: ", size, rows, splited_shape)
                 var = program.global_block().create_var(
                     name="%s.block%d" % (varname, i),
                     psersistable=False,
                     dtype=orig_var.dtype,
                     type=orig_var.type,
                     shape=splited_shape)  # flattend splited var
                 var_mapping[varname].append(var)
+                print("###created split var ", var)
         return var_mapping

     def _clone_var(self, block, var):
@@ -528,6 +531,8 @@ def get_pserver_program(self, endpoint):
         """
         # step5
         pserver_program = Program()
+        print("param mapping on pserver: #### ",
+              self.param_grad_ep_mapping[endpoint]["params"])
         for v in self.param_grad_ep_mapping[endpoint]["params"]:
             self._clone_var(pserver_program.global_block(), v)
         for v in self.param_grad_ep_mapping[endpoint]["grads"]:
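The transpiler changes are debug prints around the block-splitting path. For reference, the arithmetic visible in the context lines (split_count as a ceiling division, with a smaller final block) works like this hypothetical standalone sketch:

#include <algorithm>
#include <cstdint>
#include <iostream>

int main() {
  // E.g. a variable of 10 elements with block_size 4 splits into
  // ceil(10 / 4) = 3 blocks of sizes 4, 4, 2 (last block is the remainder).
  int64_t var_numel = 10, block_size = 4;
  int64_t split_count = (var_numel + block_size - 1) / block_size;
  for (int64_t block_id = 0; block_id < split_count; ++block_id) {
    int64_t curr = std::min(block_size, var_numel - block_id * block_size);
    std::cout << "block " << block_id << ": " << curr << " elements\n";
  }
  return 0;
}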
