
Commit 0240bb7

Merge pull request #8516 from QiJune/memopt_multi_gpu
make memory optimization module compatible with parallel_do
2 parents acbda44 + 191d8dc commit 0240bb7
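
The change itself is confined to the memory-optimization transpiler: the pass now walks parallel_do / parallel_do_grad sub-blocks in addition to while / while_grad. The usage pattern it targets is the one exercised by the updated fit-a-line test below; a minimal sketch, assuming the fluid APIs exactly as used in this commit (the network and hyper-parameters mirror that test and are illustrative only):

    import paddle.fluid as fluid

    # Build a small network whose forward/backward ops live inside a
    # parallel_do / parallel_do_grad sub-block pair.
    x = fluid.layers.data(name='x', shape=[13], dtype='float32')
    y = fluid.layers.data(name='y', shape=[1], dtype='float32')

    places = fluid.layers.get_places(device_count=0, device_type='CPU')
    pd = fluid.layers.ParallelDo(places, use_nccl=False)
    with pd.do():
        x_ = pd.read_input(x)
        y_ = pd.read_input(y)
        y_predict = fluid.layers.fc(input=x_, size=1, act=None)
        cost = fluid.layers.square_error_cost(input=y_predict, label=y_)
        avg_cost = fluid.layers.mean(x=cost)
        pd.write_output(avg_cost)

    cost = pd()
    avg_cost = fluid.layers.mean(x=cost)
    sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.01)
    sgd_optimizer.minimize(avg_cost)

    # After this commit the transpiler also recognizes the parallel_do /
    # parallel_do_grad pair, so these sub-blocks are optimized as well.
    fluid.memory_optimize(fluid.default_main_program())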

2 files changed: +82 additions, −58 deletions

python/paddle/fluid/memory_optimization_transpiler.py

Lines changed: 61 additions & 51 deletions
@@ -29,6 +29,8 @@
     core.VarDesc.VarType.BOOL: 1
 }
 
+sub_block_ops = ["while", "while_grad", "parallel_do", "parallel_do_grad"]
+
 
 class ControlFlowGraph(object):
     def __init__(self, Program, ops, forward_num, skip_opt):
@@ -141,7 +143,7 @@ def check_var_validity(block_desc, x, is_forward):
         self.pool = []
         for i in range(self.op_size):
             op = self._ops[i]
-            if op.type() == "while" or op.type() == "while_grad":
+            if op.type() in sub_block_ops:
                 continue
             block_desc = op.block()
             is_forward = i < self._forward_num
@@ -198,67 +200,75 @@ def check_var_validity(block_desc, x, is_forward):
                             block_desc, var_name, is_forward).shape()))
 
 
-def get_cfgs(input_program):
+def _process_sub_block_pair(pdesc, sub_block_pair):
     ops_list = []
-    pdesc = input_program.get_desc()
     block_desc = pdesc.block(0)
     op_size = block_desc.op_size()
-    # Get global block ops
-    ops_list.append(
-        ([block_desc.op(i) for i in range(op_size)], op_size, set()))
-
-    while_sub_block_ids = []
-    while_grad_sub_block_ids = []
-    while_block_id_pair = []
-    while_op_dict = {}
+    for fwd_op, bwd_op in sub_block_pair:
+        sub_block_ids = []
+        grad_sub_block_ids = []
+        sub_block_id_pair = []
+        sub_op_dict = {}
+        for i in range(op_size):
+            op = block_desc.op(i)
+            if op.type() == fwd_op:
+                sub_block_ids.append(op.attr("sub_block").id)
+                sub_op_dict[op.attr("sub_block").id] = op
+            elif op.type() == bwd_op:
+                grad_sub_block_ids.append(op.attr("sub_block").id)
+                sub_op_dict[op.attr("sub_block").id] = op
 
-    for i in range(op_size):
-        op = block_desc.op(i)
-        if op.type() == "while":
-            while_sub_block_ids.append(op.attr("sub_block").id)
-            while_op_dict[op.attr("sub_block").id] = op
-        elif op.type() == "while_grad":
-            while_grad_sub_block_ids.append(op.attr("sub_block").id)
-            while_op_dict[op.attr("sub_block").id] = op
+        # Find fwd_op/bwd_op block pair
+        for grad_id in grad_sub_block_ids:
+            fwd_id = pdesc.block(grad_id).get_forward_block_idx()
+            if fwd_id in sub_block_ids:
+                sub_block_id_pair.append((fwd_id, grad_id))
+                sub_block_ids.remove(fwd_id)
 
-    # Find while/while_grad block pair
-    for grad_id in while_grad_sub_block_ids:
-        forward_id = pdesc.block(grad_id).get_forward_block_idx()
-        if forward_id in while_sub_block_ids:
-            while_block_id_pair.append((forward_id, grad_id))
-            while_sub_block_ids.remove(forward_id)
+        # Get fwd_op/bwd_op block ops
+        for fwd_id, grad_id in sub_block_id_pair:
+            sub_block_ops = []
+            sub_block = pdesc.block(fwd_id)
+            block_op_size = sub_block.op_size()
+            for i in range(block_op_size):
+                sub_block_ops.append(sub_block.op(i))
 
-    # Get while/while_grad block ops
-    for forward_id, grad_id in while_block_id_pair:
-        while_block_ops = []
-        while_block = pdesc.block(forward_id)
-        while_block_op_size = while_block.op_size()
-        for i in range(while_block_op_size):
-            while_block_ops.append(while_block.op(i))
+            grad_sub_block = pdesc.block(grad_id)
+            grad_sub_block_op_size = grad_sub_block.op_size()
+            for i in range(grad_sub_block_op_size):
+                sub_block_ops.append(grad_sub_block.op(i))
 
-        while_grad_block = pdesc.block(grad_id)
-        while_grad_block_op_size = while_grad_block.op_size()
-        for i in range(while_grad_block_op_size):
-            while_block_ops.append(while_grad_block.op(i))
+            sub_op_output = set()
+            sub_op_output.update(sub_op_dict[fwd_id].output_arg_names())
+            sub_op_output.update(sub_op_dict[grad_id].output_arg_names())
+            ops_list.append((sub_block_ops, block_op_size, sub_op_output))
 
-        while_op_output = set()
-        while_op_output.update(while_op_dict[forward_id].output_arg_names())
-        while_op_output.update(while_op_dict[grad_id].output_arg_names())
+        # Process rest fwd_op block ops
+        for fwd_id in sub_block_ids:
+            sub_block_ops = []
+            sub_block = pdesc.block(fwd_id)
+            sub_block_op_size = sub_block.op_size()
+            for i in range(sub_block_op_size):
+                sub_block_ops.append(sub_block.op(i))
+            sub_op_output = set()
+            sub_op_output.update(sub_op_dict[fwd_id].output_arg_names())
+            ops_list.append((sub_block_ops, sub_block_op_size, sub_op_output))
+    return ops_list
 
-        ops_list.append((while_block_ops, while_block_op_size, while_op_output))
 
-    # Process rest while block ops
-    for forward_id in while_sub_block_ids:
-        while_block_ops = []
-        while_block = pdesc.block(forward_id)
-        while_block_op_size = while_block.op_size()
-        for i in range(while_block_op_size):
-            while_block_ops.append(while_block.op(i))
+def _get_cfgs(input_program):
+    ops_list = []
+    pdesc = input_program.get_desc()
+    block_desc = pdesc.block(0)
+    op_size = block_desc.op_size()
+    # Get global block ops
+    ops_list.append(
+        ([block_desc.op(i) for i in range(op_size)], op_size, set()))
 
-        while_op_output = set()
-        while_op_output.update(while_op_dict[forward_id].output_arg_names())
+    sub_block_pair = [("while", "while_grad"), ("parallel_do",
+                                                "parallel_do_grad")]
 
-        ops_list.append((while_block_ops, while_block_op_size, while_op_output))
+    ops_list.extend(_process_sub_block_pair(pdesc, sub_block_pair))
 
     cfgs = [
         ControlFlowGraph(input_program, ops, forward_num, skip_opt)
@@ -268,6 +278,6 @@ def get_cfgs(input_program):
 
 
 def memory_optimize(input_program):
-    cfgs = get_cfgs(input_program)
+    cfgs = _get_cfgs(input_program)
     for cfg in cfgs:
         cfg.memory_optimize()
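
In short, get_cfgs is split into two private helpers: _process_sub_block_pair collects the op list for any (forward op, backward op) sub-block pair, matching a gradient block to its forward block via BlockDesc.get_forward_block_idx(), while _get_cfgs only gathers the global block and delegates. A rough sketch of what the helper yields, assuming the main program already contains a while or parallel_do block (it pokes at a private function purely for illustration):

    import paddle.fluid as fluid
    from paddle.fluid import memory_optimization_transpiler as mot

    pdesc = fluid.default_main_program().get_desc()

    # The same pairs that _get_cfgs now hard-codes.
    pairs = [("while", "while_grad"), ("parallel_do", "parallel_do_grad")]

    # One (ops, forward_op_count, skip_opt_var_names) tuple per matched
    # forward/backward sub-block pair, plus one per unmatched forward block;
    # each tuple is later turned into a ControlFlowGraph.
    for ops, forward_num, skip_opt in mot._process_sub_block_pair(pdesc, pairs):
        print len(ops), forward_num, skip_opt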

python/paddle/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py

Lines changed: 21 additions & 7 deletions
@@ -24,15 +24,29 @@
 fluid.default_startup_program().random_seed = 111
 
 x = fluid.layers.data(name='x', shape=[13], dtype='float32')
-
-y_predict = fluid.layers.fc(input=x, size=1, act=None)
-
 y = fluid.layers.data(name='y', shape=[1], dtype='float32')
 
-cost = fluid.layers.square_error_cost(input=y_predict, label=y)
-avg_cost = fluid.layers.mean(cost)
+device_type = 'CPU'
+use_nccl = False
+place = fluid.CPUPlace()
+if fluid.core.is_compiled_with_cuda():
+    device_type = 'CUDA'
+    use_nccl = True
+    place = fluid.CUDAPlace(0)
 
-sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1)
+places = fluid.layers.get_places(device_count=0, device_type=device_type)
+pd = fluid.layers.ParallelDo(places, use_nccl=use_nccl)
+with pd.do():
+    x_ = pd.read_input(x)
+    y_ = pd.read_input(y)
+    y_predict = fluid.layers.fc(input=x_, size=1, act=None)
+    cost = fluid.layers.square_error_cost(input=y_predict, label=y_)
+    avg_cost = fluid.layers.mean(x=cost)
+    pd.write_output(avg_cost)
+
+cost = pd()
+avg_cost = fluid.layers.mean(x=cost)
+sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.01)
 sgd_optimizer.minimize(avg_cost)
 
 fluid.memory_optimize(fluid.default_main_program())
@@ -48,7 +62,6 @@
 #         paddle.dataset.uci_housing.train(), buf_size=500),
 #     batch_size=BATCH_SIZE)
 
-place = fluid.CPUPlace()
 feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
 exe = fluid.Executor(place)
 
@@ -65,6 +78,7 @@
 
         if avg_loss_value[0] < 10.0:
             exit(0)  # if avg cost less than 10.0, we think our code is good.
+        print avg_loss_value[0]
         if math.isnan(float(avg_loss_value)):
            sys.exit("got NaN loss, training failed.")
 exit(1)
