Skip to content

Commit 6c61641

Browse files
committed
[Add] paddle samples and testing tools
1 parent 47e9a2a commit 6c61641

File tree

12 files changed

+2306
-0
lines changed

12 files changed

+2306
-0
lines changed

graph_net/paddle/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""
2+
GraphNet Paddle Implementation
3+
"""
4+
5+
from .samples_util import get_default_samples_directory
6+
7+
__all__ = ["get_default_samples_directory"]
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
from . import utils
2+
import argparse
3+
import importlib.util
4+
import inspect
5+
from pathlib import Path
6+
from typing import Type, Any
7+
import sys
8+
import os
9+
import os.path
10+
from dataclasses import dataclass
11+
from contextlib import contextmanager
12+
import time
13+
import glob
14+
15+
16+
def get_recursively_model_pathes(root_dir):
17+
for sub_dir in _get_recursively_model_pathes(root_dir):
18+
yield os.path.realpath(sub_dir)
19+
20+
21+
def _get_recursively_model_pathes(root_dir):
22+
if is_single_model_dir(root_dir):
23+
yield root_dir
24+
return
25+
for sub_dir in get_immediate_subdirectory_paths(root_dir):
26+
if is_single_model_dir(sub_dir):
27+
yield sub_dir
28+
else:
29+
yield from get_recursively_model_pathes(sub_dir)
30+
31+
32+
def get_immediate_subdirectory_paths(parent_dir):
33+
return [
34+
sub_dir
35+
for name in os.listdir(parent_dir)
36+
for sub_dir in [os.path.join(parent_dir, name)]
37+
if os.path.isdir(sub_dir)
38+
]
39+
40+
41+
def is_single_model_dir(model_dir):
42+
return os.path.isfile(f"{model_dir}/graph_net.json")
43+
44+
45+
def main(args):
46+
assert os.path.isdir(args.model_path)
47+
assert os.path.isdir(args.graph_net_samples_path)
48+
current_model_graph_hash_pathes = set(
49+
graph_hash_path
50+
for model_path in get_recursively_model_pathes(args.model_path)
51+
for graph_hash_path in [f"{model_path}/graph_hash.txt"]
52+
)
53+
graph_hash2graph_net_model_path = {
54+
graph_hash: graph_hash_path
55+
for model_path in get_recursively_model_pathes(args.graph_net_samples_path)
56+
for graph_hash_path in [f"{model_path}/graph_hash.txt"]
57+
if os.path.isfile(graph_hash_path)
58+
if graph_hash_path not in current_model_graph_hash_pathes
59+
for graph_hash in [open(graph_hash_path).read()]
60+
}
61+
for current_model_graph_hash_path in current_model_graph_hash_pathes:
62+
graph_hash = open(current_model_graph_hash_path).read()
63+
assert (
64+
graph_hash not in graph_hash2graph_net_model_path
65+
), f"Redundant models detected. old-model-path:{current_model_graph_hash_path}, new-model-path:{graph_hash2graph_net_model_path[graph_hash]}."
66+
67+
68+
if __name__ == "__main__":
69+
parser = argparse.ArgumentParser(description="Test compiler performance.")
70+
parser.add_argument(
71+
"--model-path",
72+
type=str,
73+
required=True,
74+
help="Path to model file(s), each subdirectory containing graph_net.json will be regarded as a model",
75+
)
76+
parser.add_argument(
77+
"--graph-net-samples-path",
78+
type=str,
79+
required=False,
80+
default="default",
81+
help="Path to GraphNet samples",
82+
)
83+
args = parser.parse_args()
84+
main(args=args)
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
from . import utils
2+
import argparse
3+
import importlib.util
4+
import inspect
5+
from pathlib import Path
6+
from typing import Type, Any
7+
import sys
8+
import os
9+
import os.path
10+
from dataclasses import dataclass
11+
from contextlib import contextmanager
12+
import time
13+
import glob
14+
import shutil
15+
16+
17+
def get_recursively_model_pathes(root_dir):
18+
for sub_dir in _get_recursively_model_pathes(root_dir):
19+
yield os.path.realpath(sub_dir)
20+
21+
22+
def _get_recursively_model_pathes(root_dir):
23+
if is_single_model_dir(root_dir):
24+
yield root_dir
25+
return
26+
for sub_dir in get_immediate_subdirectory_paths(root_dir):
27+
if is_single_model_dir(sub_dir):
28+
yield sub_dir
29+
else:
30+
yield from get_recursively_model_pathes(sub_dir)
31+
32+
33+
def get_immediate_subdirectory_paths(parent_dir):
34+
return [
35+
sub_dir
36+
for name in os.listdir(parent_dir)
37+
for sub_dir in [os.path.join(parent_dir, name)]
38+
if os.path.isdir(sub_dir)
39+
]
40+
41+
42+
def is_single_model_dir(model_dir):
43+
return os.path.isfile(f"{model_dir}/graph_net.json")
44+
45+
46+
def main(args):
47+
assert os.path.isdir(args.model_path)
48+
assert os.path.isdir(args.graph_net_samples_path)
49+
current_model_graph_hash_pathes = set(
50+
graph_hash_path
51+
for model_path in get_recursively_model_pathes(args.model_path)
52+
for graph_hash_path in [f"{model_path}/graph_hash.txt"]
53+
)
54+
graph_hash2graph_net_model_path = {
55+
graph_hash: graph_hash_path
56+
for model_path in get_recursively_model_pathes(args.graph_net_samples_path)
57+
for graph_hash_path in [f"{model_path}/graph_hash.txt"]
58+
if os.path.isfile(graph_hash_path)
59+
if graph_hash_path not in current_model_graph_hash_pathes
60+
for graph_hash in [open(graph_hash_path).read()]
61+
}
62+
for current_model_graph_hash_path in current_model_graph_hash_pathes:
63+
graph_hash = open(current_model_graph_hash_path).read()
64+
if graph_hash not in graph_hash2graph_net_model_path:
65+
continue
66+
directory = os.path.dirname(current_model_graph_hash_path)
67+
shutil.rmtree(directory)
68+
os.makedirs(directory, exist_ok=True)
69+
70+
71+
if __name__ == "__main__":
72+
parser = argparse.ArgumentParser(description="Test compiler performance.")
73+
parser.add_argument(
74+
"--model-path",
75+
type=str,
76+
required=True,
77+
help="Path to model file(s), each subdirectory containing graph_net.json will be regarded as a model",
78+
)
79+
parser.add_argument(
80+
"--graph-net-samples-path",
81+
type=str,
82+
required=False,
83+
default="default",
84+
help="Path to GraphNet samples",
85+
)
86+
args = parser.parse_args()
87+
main(args=args)

graph_net/paddle/samples_util.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
import graph_net
2+
import os
3+
4+
5+
def get_default_samples_directory():
6+
return f"{os.path.dirname(graph_net.__file__)}/../samples"

0 commit comments

Comments
 (0)