Skip to content

Commit f4a69d5

Browse files
authored
[Dy2Stat] Fix eval_if_exist_else_none bug (#31261) (#31277)
* fix eval_if_exist_else_none bug * fix typo * fix typo * fix test_op_num unittest
1 parent 52f7e77 commit f4a69d5

File tree

4 files changed

+72
-4
lines changed

4 files changed

+72
-4
lines changed

python/paddle/fluid/dygraph/dygraph_to_static/convert_operators.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -302,9 +302,19 @@ def convert_var_shape_simple(x):
302302
return x.shape
303303

304304

305-
def eval_if_exist_else_none(name):
305+
def eval_if_exist_else_none(name, local_symbol_table):
306+
"""
307+
Args:
308+
name([str]): Expression passed into `eval`.
309+
local_symbol_table(dict): Specified from `locals()`. DO NOT use `globals()`,
310+
it has a higher priority and will hide away variables
311+
from `locals()`.
312+
313+
Returns:
314+
Return the variable if found in local_symbol_table else None.
315+
"""
306316
try:
307-
return eval(name)
317+
return eval(name, local_symbol_table)
308318
except:
309319
return None
310320

python/paddle/fluid/dygraph/dygraph_to_static/tensor_shape_transformer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ def create_convert_shape_node(var_shape_node,
5858

5959

6060
def create_choose_shape_node(attr_shape_name, api_shape_name, slice_node=None):
61-
eval_exist_func = "paddle.jit.dy2static.eval_if_exist_else_none('{}')".format(
61+
# Note(Aurelius84): Add `locals()` to help `eval` to locate the variable correctly.
62+
eval_exist_func = "paddle.jit.dy2static.eval_if_exist_else_none('{}', locals())".format(
6263
api_shape_name)
6364
args = [attr_shape_name, eval_exist_func]
6465

python/paddle/fluid/tests/unittests/dygraph_to_static/test_convert_operators.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import numpy as np
1616
import paddle
1717
import unittest
18+
from paddle.jit.dy2static.convert_operators import eval_if_exist_else_none
1819

1920

2021
class CallNotExist(paddle.nn.Layer):
@@ -189,5 +190,61 @@ def test_negative_attr_shape(self):
189190
paddle.shape(x))
190191

191192

193+
class TestEvaIfExistElseNone(unittest.TestCase):
194+
def test_locals(self):
195+
x_shape = [1, 2, 3]
196+
self.assertEqual(eval_if_exist_else_none('x_shape', locals()), x_shape)
197+
198+
def test_globals(self):
199+
x_shape = [1, 2, 3]
200+
201+
def foo():
202+
x_shape = [2, 3, 4]
203+
self.assertEqual(
204+
eval_if_exist_else_none('x_shape', locals()), [2, 3, 4])
205+
206+
foo()
207+
208+
def test_invisible_of_func(self):
209+
x_shape = [1, 2, 3]
210+
211+
def foo():
212+
x_shape = [2, 3, 4]
213+
return x_shape
214+
215+
self.assertEqual(
216+
eval_if_exist_else_none('x_shape', locals()), [1, 2, 3])
217+
218+
def test_none(self):
219+
def foo():
220+
x_shape = [2, 3, 4]
221+
return x_shape
222+
223+
self.assertEqual(eval_if_exist_else_none('x_shape', locals()), None)
224+
225+
226+
class ShapeLayer(paddle.nn.Layer):
227+
def __init__(self):
228+
super(ShapeLayer, self).__init__()
229+
230+
@paddle.jit.to_static(input_spec=[paddle.static.InputSpec(shape=[None, 1])])
231+
def forward(self, x):
232+
x = paddle.reshape(x, [-1, x.shape[1]])
233+
bs = x.shape[0] # -1
234+
235+
# for trigger choos_shape_attr_or_api
236+
out = paddle.zeros([bs, 1], dtype='float32')
237+
return out
238+
239+
240+
class TestChooseShapeAttrOrApiWithLayer(unittest.TestCase):
241+
def test_tensor_shape(self):
242+
x = paddle.zeros(shape=[4, 1], dtype='float32')
243+
net = ShapeLayer()
244+
out = net(x)
245+
246+
self.assertTrue(np.array_equal(out.numpy(), x.numpy()))
247+
248+
192249
if __name__ == '__main__':
193250
unittest.main()

python/paddle/fluid/tests/unittests/dygraph_to_static/test_tensor_shape.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ def _set_test_func(self):
484484
self.dygraph_func = dyfunc_with_if_1
485485

486486
def _set_expected_op_num(self):
487-
self.expected_op_num = 19
487+
self.expected_op_num = 28
488488
self.expected_shape_op_num = 4
489489
self.expected_slice_op_num = 2
490490

0 commit comments

Comments
 (0)