
Commit 77428e8

fix py_func bug when out is list and add unittest case(#22596)
1 parent 370fdaa commit 77428e8

File tree: 2 files changed, 25 additions, 1 deletion


python/paddle/fluid/layers/nn.py

Lines changed: 3 additions & 1 deletion
@@ -12598,7 +12598,9 @@ def py_func_demo():
         out_list = [out]
     elif isinstance(out, tuple):
         out_list = list(out)
-    elif not isinstance(x, (list, tuple, Variable)):
+    elif isinstance(out, list):
+        out_list = out
+    else:
         raise TypeError(
             'Output must be Variable/list(Variable)/tuple(Variable)')
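With this change, py_func normalizes out the same way whether it is a single Variable, a tuple, or a plain Python list; previously the list case fell through to a mistyped check on x and raised TypeError. A minimal sketch of the now-accepted list form, assuming the fluid.layers.data / fluid.layers.py_func API of that release; the double_both helper and the variable names are illustrative, not part of this commit:

    import numpy as np
    import paddle.fluid as fluid

    def double_both(a, b):
        # Python-side callback: returns one numpy array per output Variable
        return 2 * np.array(a), 2 * np.array(b)

    block = fluid.default_main_program().current_block()
    a = fluid.layers.data(name='a', shape=[4], dtype='float32')
    b = fluid.layers.data(name='b', shape=[4], dtype='float32')
    a_out = block.create_var(dtype='float32', shape=[-1, 4])
    b_out = block.create_var(dtype='float32', shape=[-1, 4])

    # out passed as a list of Variables -- the case this commit fixes
    fluid.layers.py_func(func=double_both, x=[a, b], out=[a_out, b_out])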

python/paddle/fluid/tests/unittests/test_py_func_op.py

Lines changed: 22 additions & 0 deletions
@@ -34,6 +34,10 @@ def dummy_func_with_no_output(x):
     pass


+def dummy_func_with_multi_input_output(x, y):
+    return np.array(x), np.array(y)
+
+
 def tanh(x):
     return np.tanh(x)

@@ -109,6 +113,24 @@ def simple_fc_net(img, label, use_py_func_op):
         loss += dummy_var
         fluid.layers.py_func(func=dummy_func_with_no_output, x=loss, out=None)

+        loss_out = fluid.default_main_program().current_block().create_var(
+            dtype='float32', shape=[-1, 1])
+        dummy_var_out = fluid.default_main_program().current_block().create_var(
+            dtype='float32', shape=[1])
+        fluid.layers.py_func(
+            func=dummy_func_with_multi_input_output,
+            x=(loss, dummy_var),
+            out=(loss_out, dummy_var_out))
+        assert loss == loss_out and dummy_var == dummy_var_out, \
+            "py_func failed with multi input and output"
+
+        fluid.layers.py_func(
+            func=dummy_func_with_multi_input_output,
+            x=[loss, dummy_var],
+            out=[loss_out, dummy_var_out])
+        assert loss == loss_out and dummy_var == dummy_var_out, \
+            "py_func failed with multi input and output"
+
     loss = fluid.layers.mean(loss)
     return loss

0 commit comments