Skip to content

Commit 70386d1

Browse files
committed
Allow to specify the verified samples.
1 parent 4f8d8da commit 70386d1

File tree

2 files changed

+44
-20
lines changed

2 files changed

+44
-20
lines changed

graph_net/paddle/test_compiler.py

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -317,30 +317,42 @@ def get_cmp_diff_count(expected_out, compiled_out, atol, rtol):
317317

318318

319319
def test_multi_models(args):
320+
verified_samples = None
321+
if args.verified_samples_path is not None:
322+
assert os.path.isfile(args.verified_samples_path)
323+
graphnet_root = path_utils.get_graphnet_root()
324+
print(f"graphnet_root: {graphnet_root}")
325+
verified_samples = []
326+
with open(args.verified_samples_path, "r") as f:
327+
for line in f.readlines():
328+
verified_samples.append(os.path.join(graphnet_root, line.strip()))
329+
320330
sample_idx = 0
321331
failed_samples = []
322332
for model_path in path_utils.get_recursively_model_path(args.model_path):
323-
print(f"[{sample_idx}] test_compiler, model_path: {model_path}")
324-
cmd = " ".join(
325-
[
326-
sys.executable,
327-
"-m graph_net.paddle.test_compiler",
328-
f"--model-path {model_path}",
329-
f"--compiler {args.compiler}",
330-
f"--device {args.device}",
331-
f"--warmup {args.warmup}",
332-
f"--trials {args.trials}",
333-
f"--log-prompt {args.log_prompt}",
334-
f"--output-dir {args.output_dir}",
335-
]
336-
)
337-
cmd_ret = os.system(cmd)
338-
# assert cmd_ret == 0, f"{cmd_ret=}, {cmd=}"
339-
if cmd_ret != 0:
340-
failed_samples.append(model_path)
341-
sample_idx += 1
333+
if verified_samples is None or os.path.abspath(model_path) in verified_samples:
334+
print(f"[{sample_idx}] test_compiler, model_path: {model_path}")
335+
cmd = " ".join(
336+
[
337+
sys.executable,
338+
"-m graph_net.paddle.test_compiler",
339+
f"--model-path {model_path}",
340+
f"--compiler {args.compiler}",
341+
f"--device {args.device}",
342+
f"--warmup {args.warmup}",
343+
f"--trials {args.trials}",
344+
f"--log-prompt {args.log_prompt}",
345+
]
346+
)
347+
cmd_ret = os.system(cmd)
348+
# assert cmd_ret == 0, f"{cmd_ret=}, {cmd=}"
349+
if cmd_ret != 0:
350+
failed_samples.append(model_path)
351+
sample_idx += 1
342352

343-
print(f"Totally {sample_idx} samples, failed {len(failed_samples)} samples.")
353+
print(
354+
f"Totally {sample_idx} verified samples, failed {len(failed_samples)} samples."
355+
)
344356
for model_path in failed_samples:
345357
print(f"- {model_path}")
346358

@@ -393,5 +405,12 @@ def main(args):
393405
default="graph-net-test-compiler-log",
394406
help="Log prompt for performance log filtering.",
395407
)
408+
parser.add_argument(
409+
"--verified-samples-path",
410+
type=str,
411+
required=False,
412+
default=None,
413+
help="Path to model file(s), each subdirectory containing graph_net.json will be regarded as a model",
414+
)
396415
args = parser.parse_args()
397416
main(args=args)

graph_net/path_utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
11
import os
2+
import graph_net
3+
4+
5+
def get_graphnet_root():
6+
return os.path.dirname(os.path.dirname(graph_net.__file__))
27

38

49
def is_single_model_dir(model_dir):

0 commit comments

Comments
 (0)