Skip to content

Commit b3fb4bf

Browse files
authored
[Bug Fix] handle return value of list type (#216)
1 parent 89d991b commit b3fb4bf

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

graph_net/paddle/validate.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import graph_net
1313
import os
1414
import re
15+
import paddle
1516

1617

1718
def load_class_from_file(file_path: str, class_name: str):
@@ -68,7 +69,12 @@ def main(args):
6869
y = model(**state_dict)[0]
6970

7071
print(np.argmin(y), np.argmax(y))
71-
print(y.shape)
72+
if isinstance(y, paddle.Tensor):
73+
print(y.shape)
74+
elif isinstance(y, list) or isinstance(y, tuple):
75+
print(y[0].shape)
76+
else:
77+
raise ValueError("Illegal Return Value.")
7278

7379
if not args.no_check_redundancy:
7480
print("Check redundancy ...")

0 commit comments

Comments
 (0)