
Commit a5ef6bf

Merge pull request #16867 from velconia/local_rel_1_4_dygraph_untrack_op

Imperative untrack op in eval mode

2 parents 044b7fc + dc19c6f

4 files changed: +289 −34 lines

python/paddle/fluid/dygraph/layers.py

Lines changed: 12 additions & 0 deletions
@@ -48,6 +48,12 @@ def __init__(self, name_scope, dtype=core.VarDesc.VarType.FP32):

         self._helper = LayerObjectHelper(self._full_name)

+    def train(self):
+        framework._dygraph_tracer()._train_mode()
+
+    def eval(self):
+        framework._dygraph_tracer()._eval_mode()
+
     def full_name(self):
         """Full name for this layers.
@@ -254,6 +260,12 @@ class PyLayer(core.PyLayer):
     def __init__(self):
         super(PyLayer, self).__init__()

+    def train(self):
+        framework._dygraph_tracer()._train_mode()
+
+    def eval(self):
+        framework._dygraph_tracer()._eval_mode()
+
     @classmethod
     def _do_forward(cls, inputs):
         return cls._to_tuple(cls.forward(inputs))
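The two hunks give Layer and PyLayer the same pair of switches; both just flip the mode of the shared dygraph tracer. A minimal usage sketch of the intended flow (hedged: MyLayer and its input are hypothetical stand-ins, not part of this diff):

    import paddle.fluid as fluid

    with fluid.dygraph.guard():
        layer = MyLayer("my_layer")   # hypothetical fluid.dygraph.Layer subclass
        layer.train()                 # tracer retains traced ops, so backward() works
        loss = layer(data)
        loss.backward()

        layer.eval()                  # tracer stops retaining ops for backward,
        pred = layer(data)            # keeping inference memory flat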

python/paddle/fluid/dygraph/tracer.py

Lines changed: 59 additions & 8 deletions
@@ -24,7 +24,9 @@


 def release_op(op):
-    del framework._dygraph_tracer()._ops[op._trace_id]
+    del framework._dygraph_tracer()._ops[op._trace_id].inputs
+    del framework._dygraph_tracer()._ops[op._trace_id].outputs
+    del framework._dygraph_tracer()._ops[op._trace_id].backward_refs


 class Tracer(core.Tracer):
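Rather than evicting the whole _ops entry, release_op now strips only the attributes that pin variables, so the _trace_id bookkeeping survives while the memory is freed. A self-contained toy of the effect (hedged: _TracedOp is a stand-in, not a Paddle class):

    class _TracedOp(object):
        # Mirrors only the three attributes the new release_op deletes.
        def __init__(self):
            self.inputs = {"X": ["x_ivar"]}
            self.outputs = {"Out": ["out_ivar"]}
            self.backward_refs = {"X": ["x_ivar"]}

    ops = {0: _TracedOp()}                 # plays the role of tracer._ops
    del ops[0].inputs
    del ops[0].outputs
    del ops[0].backward_refs
    assert 0 in ops                        # the record is still registered
    assert not hasattr(ops[0], "inputs")   # but no longer pins any variables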
@@ -38,6 +40,7 @@ def __init__(self, block):
         self._ops = defaultdict()
         self._vars = defaultdict()
         self._trace_id = 0
+        self._train_mode = True

     def trace_var(self, name, var):
         self._vars[name] = var
@@ -46,15 +49,57 @@ def all_parameters(self):
         return list((item for name, item in six.iteritems(self._vars)
                      if isinstance(item, framework.Parameter)))

-    def trace_op(self, op, stop_gradient=False):
+    def trace_op(self, op, inputs, outputs, stop_gradient=False):
+        # TODO(minqiyang): remove this line after we take apart all
+        # backward grads and forward variables
+        if self._train_mode:
+            op.inputs = inputs
+            inps = defaultdict(list)
+            for k, vars in six.iteritems(inputs):
+                if isinstance(vars, framework.Variable):
+                    inps[k].append(vars._ivar)
+                elif isinstance(vars, list) or isinstance(vars, tuple):
+                    for var in vars:
+                        inps[k].append(var._ivar)
+
+            op.outputs = outputs
+            outs = defaultdict(list)
+            for k, vars in six.iteritems(outputs):
+                if isinstance(vars, framework.Variable):
+                    outs[k].append(vars._ivar)
+                elif isinstance(vars, list) or isinstance(vars, tuple):
+                    for var in vars:
+                        outs[k].append(var._ivar)
+        else:
+            inps = defaultdict(list)
+            for k, vars in six.iteritems(inputs):
+                if isinstance(vars, framework.Variable):
+                    op.previous_ops.append(vars.op)
+                    inps[k].append(vars._ivar)
+                elif isinstance(vars, list) or isinstance(vars, tuple):
+                    for var in vars:
+                        op.previous_ops.append(var.op)
+                        inps[k].append(var._ivar)
+
+            op.outputs = outputs
+            outs = defaultdict(list)
+            for k, vars in six.iteritems(outputs):
+                if isinstance(vars, framework.Variable):
+                    vars.op = op
+                    outs[k].append(vars._ivar)
+                elif isinstance(vars, list) or isinstance(vars, tuple):
+                    for var in vars:
+                        var.op = op
+                        outs[k].append(var._ivar)
+
         # record op's trace id
         op.iop._trace_id = self._trace_id

-        backward_refs = self.trace(op.iop, op.inputs, op.outputs, op.attrs,
+        backward_refs = self.trace(op.iop, inps, outs, op.attrs,
                                    framework._current_expected_place(),
                                    stop_gradient)

-        if not stop_gradient:
+        if not stop_gradient and self._train_mode:
             self._trace_id += 1
             self._ops[op.iop._trace_id] = op
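Both branches end by flattening the caller's inputs/outputs dicts into slot-name → _ivar lists before calling self.trace, whether a slot holds one framework.Variable or a list/tuple of them. A runnable toy of that flattening (hedged: _FakeVar stands in for framework.Variable, which needs a live dygraph context):

    from collections import defaultdict

    class _FakeVar(object):
        def __init__(self, ivar):
            self._ivar = ivar

    inputs = {"X": _FakeVar("x_ivar"),
              "Y": [_FakeVar("y0_ivar"), _FakeVar("y1_ivar")]}

    inps = defaultdict(list)
    for k, vars in inputs.items():          # the diff uses six.iteritems
        if isinstance(vars, _FakeVar):      # the diff checks framework.Variable
            inps[k].append(vars._ivar)
        elif isinstance(vars, (list, tuple)):
            for var in vars:
                inps[k].append(var._ivar)

    assert dict(inps) == {"X": ["x_ivar"], "Y": ["y0_ivar", "y1_ivar"]}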

@@ -65,10 +110,16 @@ def trace_op(self, op, stop_gradient=False):
             # TODO(minqiyang): remove all inputs and outputs after separate
             # var and grad
             op.backward_refs = defaultdict(list)
-            for k, v in six.iteritems(op.inputs):
+            for k, v in six.iteritems(inputs):
                 if k in backward_refs:
-                    op.backward_refs[k] = op.inputs[k]
+                    op.backward_refs[k] = inputs[k]

-            for k, v in six.iteritems(op.outputs):
+            for k, v in six.iteritems(outputs):
                 if k in backward_refs:
-                    op.backward_refs[k] = op.outputs[k]
+                    op.backward_refs[k] = outputs[k]
+
+    def _train_mode(self):
+        self._train_mode = True
+
+    def _eval_mode(self):
+        self._train_mode = False
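The net effect of the switch, as a hedged sketch (tracer, layer, and x are assumed names; only attributes shown in this diff are referenced):

    tracer = framework._dygraph_tracer()

    layer.train()                     # tracer._train_mode = True
    out = layer(x)                    # each traced op lands in tracer._ops,
                                      # so autograd can walk the tape backward

    layer.eval()                      # tracer._train_mode = False
    n_ops = len(tracer._ops)
    out = layer(x)
    assert len(tracer._ops) == n_ops  # eval-mode ops are not retained; the
                                      # graph survives via var.op/previous_ops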

python/paddle/fluid/framework.py

Lines changed: 17 additions & 26 deletions
@@ -411,6 +411,7 @@ def __init__(self,
                     if persistable else False)
             if persistable:
                 _dygraph_tracer().trace_var(name, self)
+            self.op = None
         else:
             self.error_clip = error_clip

@@ -939,24 +940,7 @@ def __init__(self,
                 raise ValueError(
                     "`type` to initialized an Operator can not be None.")
             self.iop = core.OpBase(type)
-
-            # TODO(minqiyang): remove these lines after we take apart all
-            # backward grads and forward variables
-            self.inputs = defaultdict(list)
-            if inputs is not None:
-                for k, v in six.iteritems(inputs):
-                    if isinstance(v, Variable):
-                        self.inputs[k].append(v._ivar)
-                    elif isinstance(v, list) or isinstance(v, tuple):
-                        self.inputs[k].extend([var._ivar for var in v])
-
-            self.outputs = defaultdict(list)
-            if outputs is not None:
-                for k, v in six.iteritems(outputs):
-                    if isinstance(v, Variable):
-                        self.outputs[k].append(v._ivar)
-                    elif isinstance(v, list) or isinstance(v, tuple):
-                        self.outputs[k].extend([var._ivar for var in v])
+            self.previous_ops = []

             self.attrs = attrs if attrs else {}
         else:
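With the op tape gone in eval mode, graph structure is kept directly on the objects: each produced Variable records its creator in the new op field, and each Operator records its producers in previous_ops. A self-contained toy of that wiring (hedged: _Op and _Var are stand-ins, not Paddle classes):

    class _Op(object):
        def __init__(self, type):
            self.type = type
            self.previous_ops = []    # ops that produced this op's inputs

    class _Var(object):
        def __init__(self):
            self.op = None            # op that produced this variable

    x = _Var()
    op1 = _Op("mul")
    x.op = op1                        # mirrors `vars.op = op` on outputs

    op2 = _Op("relu")
    op2.previous_ops.append(x.op)     # mirrors trace_op's eval branch
    assert op2.previous_ops == [op1]  # graph stays walkable without tracer._ops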
@@ -1647,15 +1631,18 @@ def append_op(self, *args, **kwargs):
                 block=self,
                 desc=None,
                 type=kwargs.get("type", None),
-                inputs=kwargs.get("inputs", None),
-                outputs=kwargs.get("outputs", None),
-                attrs=kwargs.get("attrs", None))
+                inputs=None,
+                outputs=None,
+                attrs=kwargs.get("attrs", {}))

             # record ops in tracer rather than blocks
             #
             # TODO(minqiyang): add op stop_gradient support in static mode too.
             # currently, we only support stop_gradient in dygraph mode.
-            _dygraph_tracer().trace_op(op, kwargs.get("stop_gradient", False))
+            _dygraph_tracer().trace_op(op,
+                                       kwargs.get("inputs", {}),
+                                       kwargs.get("outputs", {}),
+                                       kwargs.get("stop_gradient", False))
         else:
             op_desc = self.desc.append_op()
             op = Operator(
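Since Operator.__init__ no longer snapshots inputs/outputs, append_op forwards them straight to the tracer, which decides per mode what to retain. A hedged sketch of the new call shape (block, x, and out are assumed names):

    op = block.append_op(
        type="relu",
        inputs={"X": x},
        outputs={"Out": out},
        stop_gradient=False)

    # which, in dygraph mode, now amounts to:
    #   op = Operator(block=block, desc=None, type="relu",
    #                 inputs=None, outputs=None, attrs={})
    #   _dygraph_tracer().trace_op(op, {"X": x}, {"Out": out}, False)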
@@ -1719,10 +1706,14 @@ def _prepend_op(self, *args, **kwargs):
                 self,
                 None,
                 type=kwargs.get("type", None),
-                inputs=kwargs.get("inputs", None),
-                outputs=kwargs.get("outputs", None),
-                attrs=kwargs.get("attrs", None))
-            _dygraph_tracer().trace_op(op, kwargs.get("stop_gradient", False))
+                inputs=None,
+                outputs=None,
+                attrs=kwargs.get("attrs", {}))
+
+            _dygraph_tracer().trace_op(op,
+                                       kwargs.get("inputs", {}),
+                                       kwargs.get("outputs", {}),
+                                       kwargs.get("stop_gradient", False))
         else:
             op_desc = self.desc._prepend_op()
             op = Operator(
