Skip to content

Commit 7c73a68

Browse files
authored
test=release/1.5, cherry-pick hide not_support for dygraph (#18528)
* test=release/1.5, cherry-pick hide not_support for dygraph * test=release/1.5, cherry-pick hide not_support for dygraph
1 parent 856536b commit 7c73a68

File tree

3 files changed

+22
-4
lines changed

3 files changed

+22
-4
lines changed

python/paddle/fluid/clip.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from . import layers
2222
from . import framework
2323
from . import core
24-
from .dygraph import not_support
24+
from .dygraph.base import _not_support
2525

2626
__all__ = [
2727
'ErrorClipByValue',
@@ -336,7 +336,7 @@ def _create_operators(self, param, grad):
336336
return param, new_grad
337337

338338

339-
@not_support
339+
@_not_support
340340
def set_gradient_clip(clip, param_list=None, program=None):
341341
"""
342342
To specify parameters that require gradient clip.

python/paddle/fluid/dygraph/base.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
__all__ = [
2525
'enabled',
2626
'no_grad',
27-
'not_support',
2827
'guard',
2928
'to_variable',
3029
]
@@ -91,7 +90,7 @@ def __impl__(*args, **kwargs):
9190

9291

9392
no_grad = wrap_decorator(_no_grad_)
94-
not_support = wrap_decorator(_dygraph_not_support_)
93+
_not_support = wrap_decorator(_dygraph_not_support_)
9594

9695

9796
@signature_safe_contextmanager

python/paddle/fluid/tests/unittests/test_imperative_decorator.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import paddle.fluid as fluid
1616
import paddle.fluid.framework as framework
1717
import unittest
18+
from test_imperative_base import new_program_scope
1819

1920

2021
class TestTracerMode(unittest.TestCase):
@@ -29,6 +30,18 @@ def no_grad_func(self, a):
2930
self.assertEqual(self.tracer._train_mode, False)
3031
return a
3132

33+
@fluid.dygraph.base._not_support
34+
def not_support_func(self):
35+
return True
36+
37+
def check_not_support_rlt(self, ans):
38+
try:
39+
rlt = self.not_support_func()
40+
except AssertionError:
41+
rlt = False
42+
finally:
43+
self.assertEqual(rlt, ans)
44+
3245
def test_main(self):
3346
with fluid.dygraph.guard():
3447
self.tracer = framework._dygraph_tracer()
@@ -38,6 +51,12 @@ def test_main(self):
3851

3952
self.assertEqual(self.tracer._train_mode, self.init_mode)
4053

54+
with fluid.dygraph.guard():
55+
self.check_not_support_rlt(False)
56+
57+
with new_program_scope():
58+
self.check_not_support_rlt(True)
59+
4160

4261
class TestTracerMode2(TestTracerMode):
4362
def setUp(self):

0 commit comments

Comments
 (0)