
Commit 9942565

Merge pull request #8386 from typhoonzero/fix_dist_transpiler_develop
Fix dist transpiler develop
2 parents (da02a58 + dca9941) · commit 9942565

5 files changed: +34 -56 lines changed


paddle/fluid/operators/concat_op.h

Lines changed: 2 additions & 2 deletions
@@ -38,7 +38,7 @@ class ConcatKernel : public framework::OpKernel<T> {
       auto in_stride = framework::stride_numel(in->dims());
       StridedNumelCopyWithAxis<T>(ctx.device_context(), axis,
                                   out->data<T>() + output_offset, out_stride,
-                                  in->data<T>(), in_stride);
+                                  in->data<T>(), in_stride, in_stride[axis]);
       output_offset += in_stride[axis];
     }
   }
@@ -59,7 +59,7 @@ class ConcatGradKernel : public framework::OpKernel<T> {
       auto out_stride = framework::stride_numel(out->dims());
       StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
                                   out_stride, in->data<T>() + input_offset,
-                                  in_stride);
+                                  in_stride, out_stride[axis]);
       input_offset += out_stride[axis];
     }
   }

paddle/fluid/operators/split_op.h

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ class SplitOpKernel : public framework::OpKernel<T> {
       auto out_stride = framework::stride_numel(out->dims());
       StridedNumelCopyWithAxis<T>(ctx.device_context(), axis, out->data<T>(),
                                   out_stride, in->data<T>() + input_offset,
-                                  in_stride);
+                                  in_stride, out_stride[axis]);
       input_offset += out_stride[axis];
     }
   }

paddle/fluid/operators/strided_memcpy.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx,
5454
int64_t axis, T* dst,
5555
const framework::DDim& dst_stride_numel,
5656
const T* src,
57-
const framework::DDim& src_stride_numel) {
57+
const framework::DDim& src_stride_numel,
58+
int64_t size) {
5859
int64_t before = dst_stride_numel[0] / dst_stride_numel[axis];
5960
int64_t src_after = src_stride_numel[axis];
6061
int64_t dst_after = dst_stride_numel[axis];
@@ -82,15 +83,14 @@ inline void StridedNumelCopyWithAxis(const platform::DeviceContext& ctx,
8283
if (platform::is_cpu_place(place)) {
8384
auto& cpu_place = boost::get<platform::CPUPlace>(place);
8485
memory::Copy(cpu_place, dst + i * dst_after, cpu_place,
85-
src + i * src_after, sizeof(T) * src_after);
86+
src + i * src_after, sizeof(T) * size);
8687
} else {
8788
#ifdef PADDLE_WITH_CUDA
8889
auto& gpu_place = boost::get<platform::CUDAPlace>(place);
8990
auto& cuda_ctx =
9091
reinterpret_cast<const platform::CUDADeviceContext&>(ctx);
9192
memory::Copy(gpu_place, dst + i * dst_after, gpu_place,
92-
src + i * src_after, sizeof(T) * src_after,
93-
cuda_ctx.stream());
93+
src + i * src_after, sizeof(T) * size, cuda_ctx.stream());
9494
#else
9595
PADDLE_THROW("Paddle is not compiled with GPU");
9696
#endif
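
The three C++ changes above all serve one signature change: StridedNumelCopyWithAxis now takes the per-slice copy length (`size`) explicitly instead of always copying `src_after` elements. Concat passes `in_stride[axis]` (the input's slice extent), while split and concat-grad pass `out_stride[axis]`, since there the source is the larger tensor and its slice extent would overrun the destination slice. Below is a minimal, Paddle-free Python sketch of this copy pattern; the function name, the flat-list "buffers", and the example values are made up for illustration and are not Paddle API.

def strided_copy_with_axis(dst, dst_offset, dst_after, src, src_after, before, size):
    # Copy `size` contiguous elements per slice: slice i starts at
    # i * src_after in the source and at dst_offset + i * dst_after in the
    # destination (the analogue of the sizeof(T) * size memcpy above).
    for i in range(before):
        start = dst_offset + i * dst_after
        dst[start:start + size] = src[i * src_after:i * src_after + size]

# Concatenate two row-major 2x2 matrices along axis=1 into a 2x4 output.
a = [1, 2, 3, 4]   # [[1, 2], [3, 4]]
b = [5, 6, 7, 8]   # [[5, 6], [7, 8]]
out = [0] * 8      # 2 x 4 destination buffer
offset = 0
for src in (a, b):
    # before = 2 rows, dst_after = 4 (output stride at axis),
    # src_after = size = 2 (input stride at axis).
    strided_copy_with_axis(out, offset, 4, src, 2, 2, 2)
    offset += 2    # analogous to output_offset += in_stride[axis]
assert out == [1, 2, 5, 6, 3, 4, 7, 8]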

python/paddle/v2/fluid/distribute_transpiler.py

Lines changed: 27 additions & 46 deletions
@@ -121,6 +121,7 @@ def split_dense_variable(var_list,
                 block_size += dim1 - remains
         # update split_count after aligning
         split_count = int(math.ceil(var_numel / float(block_size)))
+        print("###split var ", var.name, var.shape, block_size, split_count)
         for block_id in xrange(split_count):
             curr_block_size = min(block_size, var_numel - (
                 (block_id) * block_size))
@@ -191,7 +192,6 @@ def transpile(self,
         for b in param_blocks:
             varname, block_id, _ = b.split(":")
             send_outputs.append(param_var_mapping[varname][int(block_id)])
-
         # let send_op know which endpoint to send which var to, eplist has the same
         # order as send_inputs.
         eplist = split_method(send_inputs, pserver_endpoints)
@@ -230,21 +230,6 @@ def transpile(self,
                 outputs={"Out": [orig_param]},
                 attrs={"axis": 0})

-        self.lr_param_mapping = self._create_lr_param_mapping()
-
-    def _create_lr_param_mapping(self):
-        lr_mapping = dict()
-        for _, opt_op in enumerate(self.optimize_ops):
-            if not opt_op.inputs or not opt_op.inputs.has_key("LearningRate") \
-                    or not opt_op.inputs.has_key("Param"):
-                continue
-            lr = opt_op.inputs["LearningRate"].name
-            param = opt_op.inputs["Param"].name
-            if not lr_mapping.has_key(lr):
-                lr_mapping.update({lr: list()})
-            lr_mapping[lr].append(param)
-        return lr_mapping
-
     def _create_vars_from_blocklist(self, program, block_list):
         # Create respective variables using the block_list
         block_map = dict()
@@ -271,13 +256,15 @@ def _create_vars_from_blocklist(self, program, block_list):
                 splited_shape = [rows]
                 if len(orig_shape) >= 2:
                     splited_shape.extend(orig_shape[1:])
+                print("###splited: ", size, rows, splited_shape)
                 var = program.global_block().create_var(
                     name="%s.block%d" % (varname, i),
                     psersistable=False,
                     dtype=orig_var.dtype,
                     type=orig_var.type,
                     shape=splited_shape)  # flattend splited var
                 var_mapping[varname].append(var)
+                print("###created split var ", var)
         return var_mapping

     def _clone_var(self, block, var):
@@ -369,18 +356,9 @@ def _get_optimizer_input_shape(self, op_type, varkey, orig_shape,
             pass
         return orig_shape

-    def _fetch_var_names(self, param_dict):
-        res = []
-        if not param_dict:
-            return res
-        for _, values in param_dict.iteritems():
-            if not isinstance(values, list):
-                values = [values]
-            res += [v.name for v in values]
-        return res
-
     def _append_pserver_ops(self, optimize_block, opt_op, endpoint):
         program = optimize_block.program
+        pserver_block = program.global_block()
         new_inputs = dict()
         # update param/grad shape first, then other inputs like
         # moment can use the updated shape
@@ -395,11 +373,11 @@ def _append_pserver_ops(self, optimize_block, opt_op, endpoint):
                     # do not append this op if current endpoint
                     # is not dealing with this grad block
                     return
-                merged_var = program.global_block().vars[grad_block.name]
+                merged_var = pserver_block.vars[grad_block.name]
                 # append merging ops if trainers > 1
                 if self.trainers > 1:
                     vars2merge = self._create_var_for_trainers(
-                        program.global_block(), grad_block, self.trainers)
+                        pserver_block, grad_block, self.trainers)
                     optimize_block.append_op(
                         type="sum",
                         inputs={"X": vars2merge},
@@ -419,41 +397,42 @@ def _append_pserver_ops(self, optimize_block, opt_op, endpoint):
                         break
                 if not param_block:
                     return
-                tmpvar = program.global_block().create_var(
+                tmpvar = pserver_block.create_var(
                     name=param_block.name,
                     persistable=True,
                     dtype=param_block.dtype,
                     shape=param_block.shape)
-
                 new_inputs[key] = tmpvar
             elif key == "LearningRate":
                 # leraning rate variable has already be created by non-optimize op,
                 # don't create it once again.
-                new_inputs[key] = program.global_block().vars[opt_op.input(key)[
-                    0]]
+                new_inputs[key] = pserver_block.vars[opt_op.input(key)[0]]

         for key in opt_op.input_names:
             new_shape = None
             if key in ["Param", "Grad", "LearningRate"]:
                 continue
-            var = program.global_block().vars[opt_op.input(key)[0]]
+            var = self.program.global_block().vars[opt_op.input(key)[0]]
             # update accumulator variable shape
             param_shape = new_inputs["Param"].shape
             new_shape = self._get_optimizer_input_shape(opt_op.type, key,
                                                         var.shape, param_shape)
-            tmpvar = program.global_block().create_var(
+            tmpvar = pserver_block.create_var(
                 name=var.name,
                 persistable=var.persistable,
                 dtype=var.dtype,
                 shape=new_shape)
             new_inputs[key] = tmpvar

         # change output's ParamOut variable
-        opt_op.outputs["ParamOut"] = new_inputs["Param"]
+        outputs = self._get_output_map_from_op(self.program.global_block().vars,
+                                               opt_op)
+        outputs["ParamOut"] = new_inputs["Param"]
+
         optimize_block.append_op(
             type=opt_op.type,
             inputs=new_inputs,
-            outputs=opt_op.outputs,
+            outputs=outputs,
             attrs=opt_op.attrs)

     def _append_pserver_non_opt_ops(self, optimize_block, opt_op):
@@ -497,11 +476,12 @@ def _is_op_connected(self, op1, op2):
         # If one op's input is another op's output or
         # one op's output is another op's input, we say
         # the two operator is connected.
-        op1_input_names = self._fetch_var_names(op1.inputs)
-        op1_output_names = self._fetch_var_names(op1.outputs)
+        op1_input_names = op1.desc.input_arg_names()
+        op1_output_names = op1.desc.output_arg_names()
+
+        op2_input_names = op2.desc.input_arg_names()
+        op2_output_names = op2.desc.output_arg_names()

-        op2_input_names = self._fetch_var_names(op2.inputs)
-        op2_output_names = self._fetch_var_names(op2.outputs)
         if set(op1_output_names) & set(op2_input_names) or \
            set(op1_input_names) & set(op2_output_names):
             return True
@@ -521,21 +501,21 @@ def _create_ufind(self, optimize_ops):
     def _is_opt_op(self, op):
         # NOTE: It's a HACK implement.
         # optimize op: SGDOptimize, MomentumOptimizer, AdamOptimizer and etc...
-        if op.inputs and op.inputs.has_key("Param") \
-                and op.inputs.has_key("LearningRate"):
+        if "Param" in op.input_names and \
+                "LearningRate" in op.input_names:
             return True
         return False

     def _is_opt_op_on_pserver(self, endpoint, op):
         param_names = [
             p.name for p in self.param_grad_ep_mapping[endpoint]["params"]
         ]
-        if op.inputs["Param"].name in param_names:
+        if op.input("Param") in param_names:
             return True
         else:
             for n in param_names:
-                param = op.inputs["Param"].name
-                if same_or_split_var(n, param) and n != op.inputs["Param"].name:
+                param = op.input("Param")[0]
+                if same_or_split_var(n, param) and n != param:
                     return True
             return False
         return False
@@ -551,6 +531,8 @@ def get_pserver_program(self, endpoint):
         """
         # step5
         pserver_program = Program()
+        print("param mapping on pserver: #### ",
+              self.param_grad_ep_mapping[endpoint]["params"])
         for v in self.param_grad_ep_mapping[endpoint]["params"]:
             self._clone_var(pserver_program.global_block(), v)
         for v in self.param_grad_ep_mapping[endpoint]["grads"]:
@@ -564,7 +546,6 @@ def get_pserver_program(self, endpoint):
                 persistable=True,
                 dtype=v.dtype,
                 shape=v.shape)
-
         # step6
         optimize_block = pserver_program.create_block(0)
         # step 6.1
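
A unifying thread in the Python diff is that the transpiler stops reading the Python-side `op.inputs` / `op.outputs` dictionaries (removed from framework.py below) and instead queries operator metadata through `op.input(...)`, `op.input_names`, and `op.desc.input_arg_names()` / `output_arg_names()`. The following self-contained sketch mimics the set-intersection connectivity test used by `_is_op_connected`; `FakeOp` and the variable names are stand-ins for illustration, not Paddle code.

from collections import namedtuple

# Stand-in for an operator that only exposes its argument names,
# the way the transpiler now reads them from op.desc.
FakeOp = namedtuple("FakeOp", ["input_arg_names", "output_arg_names"])

def is_op_connected(op1, op2):
    # Two ops are connected when one consumes a variable the other produces.
    return bool(set(op1.output_arg_names) & set(op2.input_arg_names) or
                set(op1.input_arg_names) & set(op2.output_arg_names))

sgd = FakeOp(input_arg_names=["fc_0.w_0", "fc_0.w_0@GRAD", "learning_rate"],
             output_arg_names=["fc_0.w_0"])
lr_decay = FakeOp(input_arg_names=["base_lr"],
                  output_arg_names=["learning_rate"])
mean = FakeOp(input_arg_names=["cross_entropy"],
              output_arg_names=["mean_out"])

assert is_op_connected(lr_decay, sgd)       # lr_decay produces sgd's learning rate
assert not is_op_connected(mean, lr_decay)  # unrelated ops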

python/paddle/v2/fluid/framework.py

Lines changed: 0 additions & 3 deletions
@@ -400,9 +400,6 @@ def __init__(self,
         """
         self.block = block
         self.desc = desc
-        # for clone a new operator
-        self.inputs = inputs
-        self.outputs = outputs
         self.attrs = attrs
         if len(self.desc.type()) != 0:
             return
