Skip to content

Commit 23252a3

Browse files
authored
Merge pull request #12463 from jacquesqiao/remove-duplicated-init-op
avoid duplicated init op for one parameter
2 parents 93152b0 + 641a7fb commit 23252a3

File tree

3 files changed

+190
-125
lines changed

3 files changed

+190
-125
lines changed

python/paddle/fluid/framework.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1038,7 +1038,26 @@ def create_parameter(self, *args, **kwargs):
10381038
global_block = self.program.global_block()
10391039
param = Parameter(global_block, *args, **kwargs)
10401040
if 'initializer' in kwargs:
1041-
kwargs['initializer'](param, self)
1041+
1042+
def _is_inited_by(block, var):
1043+
init_ops = []
1044+
for op in block.ops:
1045+
if var.name in op.output_arg_names:
1046+
init_ops.append(op)
1047+
return init_ops
1048+
1049+
initializer = kwargs['initializer']
1050+
init_ops = _is_inited_by(global_block, param)
1051+
init_ops_len = len(init_ops)
1052+
if init_ops_len > 1:
1053+
raise RuntimeError("param " + param.name +
1054+
" is inited by multiple init ops " + str(
1055+
init_ops))
1056+
elif init_ops_len == 1:
1057+
#TODO already inited, do nothing, should log a warning
1058+
pass
1059+
else:
1060+
initializer(param, self)
10421061
return param
10431062

10441063
def append_op(self, *args, **kwargs):

python/paddle/fluid/tests/unittests/test_dist_transpiler.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,9 +73,18 @@ def _transpiler_instance(self, config=None):
7373

7474
return self.transpiler
7575

76+
def transpiler_test_impl(self):
77+
pass
7678

77-
class TestBasicModel(TranspilerTest):
7879
def test_transpiler(self):
80+
main = fluid.Program()
81+
startup = fluid.Program()
82+
with fluid.program_guard(main, startup):
83+
self.transpiler_test_impl()
84+
85+
86+
class TestBasicModel(TranspilerTest):
87+
def transpiler_test_impl(self):
7988
pserver, startup = self.get_pserver(self.pserver1_ep)
8089
pserver2, startup2 = self.get_pserver(self.pserver2_ep)
8190

@@ -123,7 +132,7 @@ def test_transpiler(self):
123132

124133

125134
class TestBasicModelWithLargeBlockSize(TranspilerTest):
126-
def test_transpiler(self):
135+
def transpiler_test_impl(self):
127136
config = fluid.DistributeTranspilerConfig()
128137
config.min_block_size = 1048576
129138

@@ -148,7 +157,7 @@ def test_transpiler(self):
148157
["sum", "scale", "sgd"])
149158
# confirm startup program
150159
self.assertEqual([op.type for op in startup.global_block().ops],
151-
["fill_constant", "fill_constant", "fill_constant"])
160+
["fill_constant", "fill_constant"])
152161
# the variable #fc_w will be split into two blocks
153162
fc_w_var = startup2.global_block().var("fc_w")
154163
self.assertEqual(fc_w_var.shape, (1000L, 1000L))
@@ -177,7 +186,7 @@ class TestNoSliceVar(TranspilerTest):
177186
def setUp(self):
178187
super(TestNoSliceVar, self).setUp()
179188

180-
def test_transpiler(self):
189+
def transpiler_test_impl(self):
181190
config = fluid.DistributeTranspilerConfig()
182191
config.slice_var_up = False
183192

@@ -212,7 +221,7 @@ def net_conf(self):
212221
sgd_optimizer.minimize(avg_cost)
213222
return
214223

215-
def test_transpiler(self):
224+
def transpiler_test_impl(self):
216225
pserver, startup = self.get_pserver(self.pserver1_ep)
217226
trainer = self.get_trainer()
218227

@@ -242,7 +251,7 @@ def net_conf(self):
242251
sgd_optimizer.minimize(avg_cost)
243252
return
244253

245-
def test_transpiler(self):
254+
def transpiler_test_impl(self):
246255
pserver, startup = self.get_pserver(self.pserver1_ep)
247256
trainer = self.get_trainer()
248257

@@ -291,7 +300,7 @@ def net_conf(self):
291300
sgd_optimizer.minimize(avg_cost)
292301
return
293302

294-
def test_transpiler(self):
303+
def transpiler_test_impl(self):
295304
pserver, startup = self.get_pserver(self.pserver1_ep)
296305
trainer = self.get_trainer()
297306

@@ -326,7 +335,7 @@ def net_conf(self):
326335
sgd_optimizer.minimize(avg_cost)
327336
return
328337

329-
def test_transpiler(self):
338+
def transpiler_test_impl(self):
330339
pserver, startup = self.get_pserver(self.pserver1_ep)
331340
trainer = self.get_trainer()
332341

0 commit comments

Comments
 (0)