Skip to content

Commit 165ae4b

Browse files
committed
fix bugs in check_validate.sh
1 parent 5db0b63 commit 165ae4b

File tree

3 files changed

+13
-4
lines changed

3 files changed

+13
-4
lines changed

graph_net/torch/check_redundant_incrementally.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def main(args):
6363
graph_hash = open(current_model_graph_hash_path).read()
6464
assert (
6565
graph_hash not in graph_hash2graph_net_model_path
66-
), f"Redundant models detected. old-model-path:{current_model_graph_hash_path}, new-model-path:{graph_hash2graph_net_model_path[graph_hash]}."
66+
), f"Redundant models detected. old-model-path:{graph_hash2graph_net_model_path[graph_hash]}, new-model-path:{current_model_graph_hash_path}."
6767

6868

6969
if __name__ == "__main__":

graph_net/torch/validate.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,9 @@ def main(args):
2929
if not args.no_check_redundancy:
3030
print("Check redundancy ...")
3131
graph_net_samples_path = (
32-
graph_net.torch.samples_util.get_default_samples_directory()
32+
(graph_net.torch.samples_util.get_default_samples_directory())
33+
if args.graph_net_samples_path is None
34+
else args.graph_net_samples_path
3335
)
3436
cmd = f"{sys.executable} -m graph_net.torch.check_redundant_incrementally --model-path {args.model_path} --graph-net-samples-path {graph_net_samples_path}"
3537
cmd_ret = os.system(cmd)
@@ -49,6 +51,13 @@ def main(args):
4951
required=True,
5052
help="Path to folder e.g '../../samples/torch/resnet18'",
5153
)
54+
parser.add_argument(
55+
"--graph-net-samples-path",
56+
type=str,
57+
required=False,
58+
default=None,
59+
help="Path to GraphNet samples folder. e.g '../../samples'",
60+
)
5261
parser.add_argument(
5362
"--no-check-redundancy",
5463
action="store_true",

tools/ci/check_validate.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ function check_validation() {
4444
fail_name=()
4545
for model_path in ${MODIFIED_MODEL_PATHS[@]}
4646
do
47-
python -m graph_net.torch.validate --model-path ${GRAPH_NET_EXTRACT_WORKSPACE}/${model_path} >&2
47+
python -m graph_net.torch.validate --model-path ${GRAPH_NET_EXTRACT_WORKSPACE}/${model_path} --graph-net-samples-path ${GRAPH_NET_EXTRACT_WORKSPACE}/samples >&2
4848
[ $? -ne 0 ] && fail_name[${#fail_name[@]}]="${model_path}"
4949
done
5050
if [ ${#fail_name[@]} -ne 0 ]
@@ -76,4 +76,4 @@ function main() {
7676
LOG "[INFO] check_validation run success and no error!"
7777
}
7878

79-
main
79+
main

0 commit comments

Comments
 (0)