Skip to content

Commit eed736d

Browse files
authored
[Dy2stat]Allow users to switch eval/train mode when using @to_static to decorate a function (#37383) (#37432)
本PR之前使用@to_static装饰一个单独的function时,对于生成的Program无法切换train/eval模式,只能运行在train模式下。这也就导致动转静后用户多次调用function显存会一直增长。 本PR之后,使用@to_static装饰一个单独的function时,可以通过function.train()或者function.eval()的方式来切换train/eval模式。
1 parent 2778fcd commit eed736d

File tree

2 files changed

+68
-0
lines changed

2 files changed

+68
-0
lines changed

python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,25 @@ def __init__(self, function, input_spec=None, **kwargs):
261261
# Note: Hold a reference to ProgramTranslator for switching `enable_to_static`.
262262
self._program_trans = ProgramTranslator()
263263
self._kwargs = kwargs
264+
self._training = True
265+
266+
def train(self):
267+
if isinstance(self._class_instance,
268+
layers.Layer) and self._class_instance.training == False:
269+
raise RuntimeError(
270+
"Failed to switch train mode. {} is a Layer's method, "
271+
"please use Layer.train() to switch train mode.".format(
272+
self.dygraph_function))
273+
self._training = True
274+
275+
def eval(self):
276+
if isinstance(self._class_instance,
277+
layers.Layer) and self._class_instance.training == True:
278+
raise RuntimeError(
279+
"Failed to switch eval mode. {} is a Layer's method, "
280+
"please use Layer.eval() to switch eval mode.".format(
281+
self.dygraph_function))
282+
self._training = False
264283

265284
def __get__(self, instance, owner):
266285
"""
@@ -340,6 +359,8 @@ def __call__(self, *args, **kwargs):
340359
# 3. synchronize self.training attribute.
341360
if isinstance(self._class_instance, layers.Layer):
342361
partial_program_layer.training = self._class_instance.training
362+
else:
363+
partial_program_layer.training = self._training
343364

344365
# 4. return outputs.
345366
try:

python/paddle/fluid/tests/unittests/dygraph_to_static/test_program_translator.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,5 +297,52 @@ def test_raise_error(self):
297297
self.program_translator.get_program(net.forward, self.x)
298298

299299

300+
class SwitchModeNet(paddle.nn.Layer):
301+
def __init__(self):
302+
super(SwitchModeNet, self).__init__()
303+
304+
@paddle.jit.to_static
305+
def forward(self, x):
306+
return x + 1
307+
308+
@paddle.jit.to_static
309+
def foo(self):
310+
return True
311+
312+
313+
@paddle.jit.to_static
314+
def switch_mode_funciton():
315+
return True
316+
317+
318+
class TestFunctionTrainEvalMode(unittest.TestCase):
319+
def test_switch_mode(self):
320+
paddle.disable_static()
321+
switch_mode_funciton.eval()
322+
switch_mode_funciton()
323+
self.assertEqual(switch_mode_funciton._training, False)
324+
_, partial_layer = switch_mode_funciton.program_cache.last()[-1]
325+
self.assertEqual(partial_layer.training, False)
326+
327+
switch_mode_funciton.train()
328+
switch_mode_funciton()
329+
self.assertEqual(switch_mode_funciton._training, True)
330+
_, partial_layer = switch_mode_funciton.program_cache.last()[-1]
331+
self.assertEqual(partial_layer.training, True)
332+
333+
def test_raise_error(self):
334+
paddle.disable_static()
335+
net = SwitchModeNet()
336+
337+
self.assertEqual(net.training, True)
338+
with self.assertRaises(RuntimeError):
339+
net.forward.eval()
340+
341+
net.eval()
342+
self.assertEqual(net.training, False)
343+
with self.assertRaises(RuntimeError):
344+
net.foo.train()
345+
346+
300347
if __name__ == '__main__':
301348
unittest.main()

0 commit comments

Comments
 (0)