Skip to content

Commit e4e5bad

Browse files
authored
Merge pull request #16908 from velconia/local_rel_1_4_dygraph_untrack_op
imperative fix train mode
2 parents 68b9332 + 40e28de commit e4e5bad

File tree

3 files changed

+7
-6
lines changed

3 files changed

+7
-6
lines changed

python/paddle/fluid/dygraph/layers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,10 @@ def __init__(self, name_scope, dtype=core.VarDesc.VarType.FP32):
4949
self._helper = LayerObjectHelper(self._full_name)
5050

5151
def train(self):
52-
framework._dygraph_tracer()._train_mode()
52+
framework._dygraph_tracer().train_mode()
5353

5454
def eval(self):
55-
framework._dygraph_tracer()._eval_mode()
55+
framework._dygraph_tracer().eval_mode()
5656

5757
def full_name(self):
5858
"""Full name for this layers.
@@ -261,10 +261,10 @@ def __init__(self):
261261
super(PyLayer, self).__init__()
262262

263263
def train(self):
264-
framework._dygraph_tracer()._train_mode()
264+
framework._dygraph_tracer().train_mode()
265265

266266
def eval(self):
267-
framework._dygraph_tracer()._eval_mode()
267+
framework._dygraph_tracer().eval_mode()
268268

269269
@classmethod
270270
def _do_forward(cls, inputs):

python/paddle/fluid/dygraph/tracer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,8 +118,8 @@ def trace_op(self, op, inputs, outputs, stop_gradient=False):
118118
if k in backward_refs:
119119
op.backward_refs[k] = outputs[k]
120120

121-
def _train_mode(self):
121+
def train_mode(self):
122122
self._train_mode = True
123123

124-
def _eval_mode(self):
124+
def eval_mode(self):
125125
self._train_mode = False

python/paddle/fluid/tests/unittests/test_imperative_mnist.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def test_mnist_float32(self):
117117
train_reader = paddle.batch(
118118
paddle.dataset.mnist.train(), batch_size=128, drop_last=True)
119119

120+
mnist.train()
120121
dy_param_init_value = {}
121122
for epoch in range(epoch_num):
122123
for batch_id, data in enumerate(train_reader()):

0 commit comments

Comments
 (0)