Skip to content

Commit 50cf103

Browse files
committed
make memory optimization module compatible with parallel_do
1 parent bf9ed4a commit 50cf103

File tree

2 files changed

+74
-57
lines changed

2 files changed

+74
-57
lines changed

python/paddle/v2/fluid/memory_optimization_transpiler.py

Lines changed: 61 additions & 51 deletions
Original file line number | Diff line number | Diff line change
@@ -29,6 +29,8 @@
2929
core.VarDesc.VarType.BOOL: 1
3030
}
3131

32+
sub_block_ops = ["while", "while_grad", "parallel_do", "parallel_do_grad"]
33+
3234

3335
class ControlFlowGraph(object):
3436
def __init__(self, Program, ops, forward_num, skip_opt):
@@ -141,7 +143,7 @@ def check_var_validity(block_desc, x, is_forward):
141143
self.pool = []
142144
for i in range(self.op_size):
143145
op = self._ops[i]
144-
if op.type() == "while" or op.type() == "while_grad":
146+
if op.type() in sub_block_ops:
145147
continue
146148
block_desc = op.block()
147149
is_forward = i < self._forward_num
@@ -198,67 +200,75 @@ def check_var_validity(block_desc, x, is_forward):
198200
block_desc, var_name, is_forward).shape()))
199201

200202

201-
def get_cfgs(input_program):
203+
def _process_sub_block_pair(pdesc, sub_block_pair):
202204
ops_list = []
203-
pdesc = input_program.get_desc()
204205
block_desc = pdesc.block(0)
205206
op_size = block_desc.op_size()
206-
# Get global block ops
207-
ops_list.append(
208-
([block_desc.op(i) for i in range(op_size)], op_size, set()))
209-
210-
while_sub_block_ids = []
211-
while_grad_sub_block_ids = []
212-
while_block_id_pair = []
213-
while_op_dict = {}
207+
for fwd_op, bwd_op in sub_block_pair:
208+
sub_block_ids = []
209+
grad_sub_block_ids = []
210+
sub_block_id_pair = []
211+
sub_op_dict = {}
212+
for i in range(op_size):
213+
op = block_desc.op(i)
214+
if op.type() == fwd_op:
215+
sub_block_ids.append(op.attr("sub_block").id)
216+
sub_op_dict[op.attr("sub_block").id] = op
217+
elif op.type() == bwd_op:
218+
grad_sub_block_ids.append(op.attr("sub_block").id)
219+
sub_op_dict[op.attr("sub_block").id] = op
214220

215-
for i in range(op_size):
216-
op = block_desc.op(i)
217-
if op.type() == "while":
218-
while_sub_block_ids.append(op.attr("sub_block").id)
219-
while_op_dict[op.attr("sub_block").id] = op
220-
elif op.type() == "while_grad":
221-
while_grad_sub_block_ids.append(op.attr("sub_block").id)
222-
while_op_dict[op.attr("sub_block").id] = op
221+
# Find fwd_op/bwd_op block pair
222+
for grad_id in grad_sub_block_ids:
223+
parent_id = pdesc.block(grad_id).parent
224+
if parent_id in sub_block_ids:
225+
sub_block_id_pair.append((parent_id, grad_id))
226+
sub_block_ids.remove(parent_id)
223227

224-
# Find while/while_grad block pair
225-
for grad_id in while_grad_sub_block_ids:
226-
parent_id = pdesc.block(grad_id).parent
227-
if parent_id in while_sub_block_ids:
228-
while_block_id_pair.append((parent_id, grad_id))
229-
while_sub_block_ids.remove(parent_id)
228+
# Get fwd_op/bwd_op block ops
229+
for parent_id, grad_id in sub_block_id_pair:
230+
sub_block_ops = []
231+
sub_block = pdesc.block(parent_id)
232+
block_op_size = sub_block.op_size()
233+
for i in range(block_op_size):
234+
sub_block_ops.append(sub_block.op(i))
230235

231-
# Get while/while_grad block ops
232-
for parent_id, grad_id in while_block_id_pair:
233-
while_block_ops = []
234-
while_block = pdesc.block(parent_id)
235-
while_block_op_size = while_block.op_size()
236-
for i in range(while_block_op_size):
237-
while_block_ops.append(while_block.op(i))
236+
grad_sub_block = pdesc.block(grad_id)
237+
grad_sub_block_op_size = grad_sub_block.op_size()
238+
for i in range(grad_sub_block_op_size):
239+
sub_block_ops.append(grad_sub_block.op(i))
238240

239-
while_grad_block = pdesc.block(grad_id)
240-
while_grad_block_op_size = while_grad_block.op_size()
241-
for i in range(while_grad_block_op_size):
242-
while_block_ops.append(while_grad_block.op(i))
241+
sub_op_output = set()
242+
sub_op_output.update(sub_op_dict[parent_id].output_arg_names())
243+
sub_op_output.update(sub_op_dict[grad_id].output_arg_names())
244+
ops_list.append((sub_block_ops, block_op_size, sub_op_output))
243245

244-
while_op_output = set()
245-
while_op_output.update(while_op_dict[parent_id].output_arg_names())
246-
while_op_output.update(while_op_dict[grad_id].output_arg_names())
246+
# Process rest fwd_op block ops
247+
for parent_id in sub_block_ids:
248+
sub_block_ops = []
249+
sub_block = pdesc.block(parent_id)
250+
sub_block_op_size = sub_block.op_size()
251+
for i in range(sub_block_op_size):
252+
sub_block_ops.append(sub_block.op(i))
253+
sub_op_output = set()
254+
sub_op_output.update(sub_op_dict[parent_id].output_arg_names())
255+
ops_list.append((sub_block_ops, sub_block_op_size, sub_op_output))
256+
return ops_list
247257

248-
ops_list.append((while_block_ops, while_block_op_size, while_op_output))
249258

250-
# Process rest while block ops
251-
for parent_id in while_sub_block_ids:
252-
while_block_ops = []
253-
while_block = pdesc.block(parent_id)
254-
while_block_op_size = while_block.op_size()
255-
for i in range(while_block_op_size):
256-
while_block_ops.append(while_block.op(i))
259+
def _get_cfgs(input_program):
260+
ops_list = []
261+
pdesc = input_program.get_desc()
262+
block_desc = pdesc.block(0)
263+
op_size = block_desc.op_size()
264+
# Get global block ops
265+
ops_list.append(
266+
([block_desc.op(i) for i in range(op_size)], op_size, set()))
257267

258-
while_op_output = set()
259-
while_op_output.update(while_op_dict[parent_id].output_arg_names())
268+
sub_block_pair = [("while", "while_grad"), ("parallel_do",
269+
"parallel_do_grad")]
260270

261-
ops_list.append((while_block_ops, while_block_op_size, while_op_output))
271+
ops_list.extend(_process_sub_block_pair(pdesc, sub_block_pair))
262272

263273
cfgs = [
264274
ControlFlowGraph(input_program, ops, forward_num, skip_opt)
@@ -268,6 +278,6 @@ def get_cfgs(input_program):
268278

269279

270280
def memory_optimize(input_program):
271-
cfgs = get_cfgs(input_program)
281+
cfgs = _get_cfgs(input_program)
272282
for cfg in cfgs:
273283
cfg.memory_optimize()

python/paddle/v2/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py

Lines changed: 13 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -24,15 +24,21 @@
2424
fluid.default_startup_program().random_seed = 111
2525

2626
x = fluid.layers.data(name='x', shape=[13], dtype='float32')
27-
28-
y_predict = fluid.layers.fc(input=x, size=1, act=None)
29-
3027
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
3128

32-
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
33-
avg_cost = fluid.layers.mean(x=cost)
29+
places = fluid.layers.get_places(device_count=2, device_type='CPU')
30+
pd = fluid.layers.ParallelDo(places)
31+
with pd.do():
32+
x_ = pd.read_input(x)
33+
y_ = pd.read_input(y)
34+
y_predict = fluid.layers.fc(input=x_, size=1, act=None)
35+
cost = fluid.layers.square_error_cost(input=y_predict, label=y_)
36+
avg_cost = fluid.layers.mean(x=cost)
37+
pd.write_output(avg_cost)
3438

35-
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1)
39+
cost = pd()
40+
avg_cost = fluid.layers.mean(x=cost)
41+
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.01)
3642
sgd_optimizer.minimize(avg_cost)
3743

3844
fluid.memory_optimize(fluid.default_main_program())
@@ -65,6 +71,7 @@
6571

6672
if avg_loss_value[0] < 10.0:
6773
exit(0) # if avg cost less than 10.0, we think our code is good.
74+
print avg_loss_value[0]
6875
if math.isnan(float(avg_loss_value)):
6976
sys.exit("got NaN loss, training failed.")
7077
exit(1)

0 commit comments

Comments (0)