Skip to content

Commit 99f45f1

Browse files
authored
[Bug Fix] Fix test_compiler when there is None in result. (PaddlePaddle#227)
* Fix test_compiler when there is None in result. * Fix validate for single return.
1 parent b16b377 commit 99f45f1

File tree

2 files changed

+12
-4
lines changed

2 files changed

+12
-4
lines changed

graph_net/paddle/test_compiler.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,15 +124,23 @@ def test_single_model(args):
124124
compiled_duration_box = DurationBox(-1)
125125
with naive_timer(compiled_duration_box, synchronizer_func):
126126
compiled_out = compiled_model(**input_dict)
127+
127128
if isinstance(expected_out, paddle.Tensor):
128129
expected_out = [expected_out]
129130
compiled_out = [compiled_out]
130131
if isinstance(expected_out, list) or isinstance(expected_out, tuple):
132+
for a, b in zip(expected_out, compiled_out):
133+
if (a is None and b is not None) or (a is not None and b is None):
134+
raise ValueError("Both expected_out and compiled_out must be not None.")
131135
expected_out = [
132-
regular_item(item) for item in expected_out if np.array(item).size != 0
136+
regular_item(item)
137+
for item in expected_out
138+
if item is not None and np.array(item).size != 0
133139
]
134140
compiled_out = [
135-
regular_item(item) for item in compiled_out if np.array(item).size != 0
141+
regular_item(item)
142+
for item in compiled_out
143+
if item is not None and np.array(item).size != 0
136144
]
137145
else:
138146
raise ValueError("Illegal return value.")

graph_net/paddle/validate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,9 @@ def main(args):
6666
params.update(inputs)
6767
state_dict = {k: utils.replay_tensor(v) for k, v in params.items()}
6868

69-
y = model(**state_dict)[0]
69+
y = model(**state_dict)
7070

71-
print(np.argmin(y), np.argmax(y))
71+
# print(np.argmin(y), np.argmax(y))
7272
if isinstance(y, paddle.Tensor):
7373
print(y.shape)
7474
elif isinstance(y, list) or isinstance(y, tuple):

0 commit comments

Comments
 (0)