Skip to content

Commit 4daaf96

Browse files
committed
Use paddle api to check the result.
1 parent 6952bc8 commit 4daaf96

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

graph_net/paddle/test_compiler.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -90,11 +90,8 @@ def get_compiled_model(args, model):
9090

9191

9292
def regular_item(item):
93-
if isinstance(item, paddle.Tensor) and (item.dtype == paddle.bfloat16):
94-
item = np.array(item.astype("float32"))
95-
else:
96-
item = np.array(item)
97-
if item.dtype == np.bool_:
93+
assert isinstance(item, paddle.Tensor)
94+
if item.dtype not in [paddle.float32, paddle.float64]:
9895
item = item.astype("float32")
9996
return item
10097

@@ -306,32 +303,34 @@ def print_cmp(key, func, **kwargs):
306303

307304
def get_cmp_equal(expected_out, compiled_out):
308305
return " ".join(
309-
str(int(np.sum(np.equal(a, b)))) for a, b in zip(expected_out, compiled_out)
306+
str(int(paddle.equal_all(a, b))) for a, b in zip(expected_out, compiled_out)
310307
)
311308

312309

313310
def get_cmp_all_close(expected_out, compiled_out, atol, rtol):
314311
return " ".join(
315-
str(int(np.allclose(a, b, atol=atol, rtol=rtol)))
312+
str(int(paddle.allclose(a, b, atol=atol, rtol=rtol)))
316313
for a, b in zip(expected_out, compiled_out)
317314
)
318315

319316

320317
def get_cmp_max_diff(expected_out, compiled_out):
321318
return " ".join(
322-
str(np.max(np.abs(a - b)).item()) for a, b in zip(expected_out, compiled_out)
319+
str(paddle.max(paddle.abs(a - b)).item())
320+
for a, b in zip(expected_out, compiled_out)
323321
)
324322

325323

326324
def get_cmp_mean_diff(expected_out, compiled_out):
327325
return " ".join(
328-
str(np.mean(np.abs(a - b)).item()) for a, b in zip(expected_out, compiled_out)
326+
str(paddle.mean(paddle.abs(a - b)).item())
327+
for a, b in zip(expected_out, compiled_out)
329328
)
330329

331330

332331
def get_cmp_diff_count(expected_out, compiled_out, atol, rtol):
333332
return " ".join(
334-
str(np.sum(~np.isclose(a, b, atol=atol, rtol=rtol)).item())
333+
str(paddle.sum(~paddle.isclose(a, b, atol=atol, rtol=rtol)).item())
335334
for a, b in zip(expected_out, compiled_out)
336335
)
337336

@@ -344,9 +343,11 @@ def test_multi_models(args):
344343
"-m graph_net.paddle.test_compiler",
345344
f"--model-path {model_path}",
346345
f"--compiler {args.compiler}",
346+
f"--device {args.device}",
347347
f"--warmup {args.warmup}",
348348
f"--trials {args.trials}",
349349
f"--log-prompt {args.log_prompt}",
350+
f"--output-dir {args.output_dir}",
350351
]
351352
)
352353
cmd_ret = os.system(cmd)

0 commit comments

Comments
 (0)