Skip to content

Commit 12c51f5

Browse files
authored
[Dy2static]Fix paddle prefix in is_paddle_api (#30569) (#30594)
[Dy2static]Fix paddle prefix in is_paddle_api (#30569) cherry-pick #30569
1 parent 3317cf0 commit 12c51f5

File tree

3 files changed

+27
-9
lines changed

3 files changed

+27
-9
lines changed

python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from paddle.fluid.dygraph.dygraph_to_static.program_translator import StaticFunction
3232
from paddle.fluid.dygraph.dygraph_to_static.program_translator import convert_to_static
3333
from paddle.fluid.dygraph.dygraph_to_static.program_translator import unwrap_decorators
34+
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_func
3435
from paddle.fluid.dygraph.layers import Layer
3536

3637
__all__ = ["convert_call"]
@@ -74,11 +75,6 @@ def is_builtin_len(func):
7475
return False
7576

7677

77-
def is_paddle_func(func):
78-
m = inspect.getmodule(func)
79-
return m is not None and m.__name__.startswith("paddle")
80-
81-
8278
def is_unsupported(func):
8379
"""
8480
Checks whether the func is supported by dygraph to static graph.

python/paddle/fluid/dygraph/dygraph_to_static/utils.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,12 @@
3030
from paddle.fluid import unique_name
3131
from paddle.fluid.data_feeder import convert_dtype
3232

33+
# Note(Aurelius): Do not forget the dot `.` to distinguish other
34+
# module such as paddlenlp.
35+
PADDLE_MODULE_PREFIX = 'paddle.'
36+
DYGRAPH_MODULE_PREFIX = 'paddle.fluid.dygraph'
37+
DYGRAPH_TO_STATIC_MODULE_PREFIX = 'paddle.fluid.dygraph.dygraph_to_static'
38+
3339

3440
class BaseNodeVisitor(gast.NodeVisitor):
3541
"""
@@ -191,16 +197,21 @@ def is_api_in_module(node, module_prefix):
191197
def is_dygraph_api(node):
192198

193199
# Note: A api in module dygraph_to_static is not a real dygraph api.
194-
if is_api_in_module(node, "paddle.fluid.dygraph.dygraph_to_static"):
200+
if is_api_in_module(node, DYGRAPH_TO_STATIC_MODULE_PREFIX):
195201
return False
196202

197203
# TODO(liym27): A better way to determine whether it is a dygraph api.
198204
# Consider the decorator @dygraph_only
199-
return is_api_in_module(node, "paddle.fluid.dygraph")
205+
return is_api_in_module(node, DYGRAPH_MODULE_PREFIX)
200206

201207

202208
def is_paddle_api(node):
203-
return is_api_in_module(node, "paddle")
209+
return is_api_in_module(node, PADDLE_MODULE_PREFIX)
210+
211+
212+
def is_paddle_func(func):
213+
m = inspect.getmodule(func)
214+
return m is not None and m.__name__.startswith(PADDLE_MODULE_PREFIX)
204215

205216

206217
# Is numpy_api cannot reuse is_api_in_module because of numpy module problem
@@ -1235,7 +1246,7 @@ def input_specs_compatible(src_input_specs, desired_input_specs):
12351246
len_specs = len(src_input_specs)
12361247
if len_specs != len(desired_input_specs):
12371248
# NOTE(chenweihang): if the input_spec of jit.save is a subset of
1238-
# input_spec of to_static, also compatible
1249+
# input_spec of to_static, also compatible
12391250
for spec in src_input_specs:
12401251
if spec not in desired_input_specs:
12411252
return False

python/paddle/fluid/tests/unittests/dygraph_to_static/test_utils.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@
1414

1515
from __future__ import print_function
1616

17+
import types
1718
import unittest
1819

1920
from paddle.fluid.dygraph.dygraph_to_static import ProgramTranslator
2021
from paddle.fluid.dygraph.dygraph_to_static.utils import index_in_list
22+
from paddle.fluid.dygraph.dygraph_to_static.utils import is_paddle_func
2123

2224
from test_program_translator import get_source_code
2325

@@ -61,5 +63,14 @@ def test_code(self):
6163
self.assertEqual(answer, code)
6264

6365

66+
class TestIsPaddle(unittest.TestCase):
67+
def fake_module(self):
68+
return types.ModuleType('paddlenlp')
69+
70+
def test_func(self):
71+
m = self.fake_module()
72+
self.assertFalse(is_paddle_func(m))
73+
74+
6475
if __name__ == '__main__':
6576
unittest.main()

0 commit comments

Comments
 (0)