Skip to content

Commit 9571045

Browse files
authored
Merge pull request #9600 from luotao1/sync_with_cpp
refine sync_with_cpp when remove ops or remove vars
2 parents 49313d4 + 103407a commit 9571045

File tree

2 files changed

+37
-4
lines changed

2 files changed

+37
-4
lines changed

python/paddle/fluid/framework.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -847,6 +847,11 @@ def sync_with_cpp(self):
847847
if not self.has_var(var.name()):
848848
self.create_var(name=var.name(), desc=var, type=var.type())
849849

850+
# sync variables removed from c++ end
851+
for var in self.vars.keys():
852+
if not self.desc.find_var(var):
853+
self.vars.pop(var)
854+
850855
# sync operators from cpp
851856
ops_in_cpp = []
852857
for op_idx in range(0, self.desc.op_size()):
@@ -881,6 +886,19 @@ def sync_with_cpp(self):
881886
op = Operator(self, op_desc)
882887
self.ops.append(op)
883888

889+
# sync ops removed from c++ end
890+
if end_index != -1 and end_index < len(self.ops):
891+
ops_in_cpp_index = 0
892+
ops_in_python_index = 0
893+
while ops_in_python_index < len(
894+
self.ops) and ops_in_cpp_index < len(ops_in_cpp):
895+
if self.ops[ops_in_python_index].desc != ops_in_cpp[
896+
ops_in_cpp_index]:
897+
del self.ops[ops_in_python_index]
898+
else:
899+
ops_in_cpp_index += 1
900+
ops_in_python_index += 1
901+
884902
assert len(self.ops) == len(ops_in_cpp)
885903
for index in range(len(self.ops)):
886904
assert self.ops[index].desc == ops_in_cpp[index]

python/paddle/fluid/tests/unittests/test_protobuf_descs.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
import unittest
1616
import paddle.fluid.core as core
17+
from paddle.fluid.framework import Program
1718

1819

1920
class TestOpDesc(unittest.TestCase):
@@ -187,32 +188,46 @@ def test_add_op(self):
187188
self.assertEqual(all_ops, [op0, op1, op2])
188189

189190
def test_remove_op(self):
190-
prog = core.ProgramDesc()
191+
program = Program()
192+
prog = program.desc
191193
self.assertIsNotNone(prog)
192194
block = prog.block(0)
193195
self.assertIsNotNone(block)
196+
197+
op0 = block.append_op()
194198
op1 = block.append_op()
195199
op2 = block.append_op()
200+
op0.set_type("test")
201+
op1.set_type("test")
202+
op2.set_type("test")
203+
204+
var0 = block.var("var0")
196205
var1 = block.var("var1")
197206
var2 = block.var("var2")
198207
var3 = block.var("var3")
199208
var4 = block.var("var4")
200209
var5 = block.var("var5")
210+
211+
op0.set_input("X", ["var0"])
212+
op0.set_output("Y", ["var0"])
201213
op1.set_input("X", ["var1", "var2"])
202214
op1.set_output("Y", ["var3", "var4"])
203215
op2.set_input("X", ["var1"])
204216
op2.set_output("Y", ["var4", "var5"])
205217

218+
program.sync_with_cpp()
219+
206220
# remove op1, its input var2 and output var3 will be removed at the same time,
207221
# but its input var1 and output var4 will not be removed since they are used for op2.
208-
block.remove_op(0, 1)
222+
block.remove_op(1, 2)
223+
program.sync_with_cpp()
209224

210225
all_ops = []
211226
for idx in xrange(0, block.op_size()):
212227
all_ops.append(block.op(idx))
213-
self.assertEqual(all_ops, [op2])
228+
self.assertEqual(all_ops, [op0, op2])
214229
all_vars = block.all_vars()
215-
self.assertEqual(set(all_vars), {var1, var4, var5})
230+
self.assertEqual(set(all_vars), {var0, var1, var4, var5})
216231

217232

218233
if __name__ == '__main__':

0 commit comments

Comments
 (0)