Skip to content

Commit a07cf18

Browse files
committed
Update paddle test compiler
1 parent c3d768a commit a07cf18

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

graph_net/paddle/test_compiler.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,19 @@ def test_single_model(args):
114114
compiled_duration_box = DurationBox(-1)
115115
with naive_timer(compiled_duration_box, synchronizer_func):
116116
compiled_out = compiled_model(**input_dict)
117-
expected_out = expected_out.numpy()
118-
compiled_out = compiled_out.numpy()
117+
118+
expected_out_list = (
119+
expected_out if isinstance(expected_out, (list, tuple)) else [expected_out]
120+
)
121+
compiled_out_list = (
122+
compiled_out if isinstance(compiled_out, (list, tuple)) else [compiled_out]
123+
)
124+
125+
processed_expected_out = [t.numpy() for t in expected_out_list]
126+
processed_compiled_out = [t.numpy() for t in compiled_out_list]
119127

120128
def print_cmp(key, func, **kwargs):
121-
cmp_ret = func(expected_out, compiled_out, **kwargs)
129+
cmp_ret = func(processed_expected_out, processed_compiled_out, **kwargs)
122130
print(
123131
f"{args.log_prompt} {key} model_path:{args.model_path} {cmp_ret}",
124132
file=sys.stderr,

0 commit comments

Comments
 (0)