Skip to content

Commit db0c1ea

Browse files
authored
[Dy2stat] Fix function lookup bug in convert_call (#24567)
* fix convert call globals_funcs test=develop * add import statement test=develop
1 parent 217ca77 commit db0c1ea

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,10 @@ def dyfunc(x):
102102
if func.__name__ == '<lambda>':
103103
return func
104104
try:
105-
if func in func.__globals__.values():
105+
global_funcs = set([
106+
fn for fn in func.__globals__.values() if inspect.isfunction(fn)
107+
])
108+
if func in global_funcs:
106109
converted_call = to_static_func(func)
107110
func_self = getattr(func, '__self__', None)
108111
except AttributeError:

python/paddle/fluid/dygraph/dygraph_to_static/program_translator.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -368,9 +368,10 @@ def func(x):
368368
369369
prog_trans = fluid.dygraph.ProgramTranslator()
370370
371-
x = np.ones([1, 2])
372-
x_v = prog_trans.get_output(func, x)
373-
print(x_v.numpy()) # [[0. 0.]]
371+
with fluid.dygraph.guard():
372+
x = np.ones([1, 2])
373+
x_v = prog_trans.get_output(func, x)
374+
print(x_v.numpy()) # [[0. 0.]]
374375
375376
"""
376377
assert callable(
@@ -472,7 +473,7 @@ def func(x):
472473
x = np.ones([1, 2])
473474
main_prog, start_prog, inputs, outputs = prog_trans.get_program(func, x)
474475
print([i.name for i in inputs])
475-
# ['x_0'] the feed input variable name representing x
476+
# ['feed_0'] the feed input variable name representing x
476477
print([o.name for o in outputs])
477478
# ['_generated_var_4'] the fetch output variable name representing x_v
478479
@@ -573,6 +574,7 @@ def save_inference_model(self, dirname, feed=None, fetch=None):
573574
import numpy as np
574575
import paddle.fluid as fluid
575576
from paddle.fluid.dygraph import Linear
577+
from paddle.fluid.dygraph import declarative
576578
from paddle.fluid.dygraph import ProgramTranslator
577579
578580
class SimpleNet(fluid.dygraph.Layer):

0 commit comments

Comments
 (0)