|
4 | 4 |
|
5 | 5 | filename = os.path.abspath(__file__) |
6 | 6 | root_dir = os.path.dirname(os.path.dirname(filename)) |
7 | | -samples_dir = os.path.join(root_dir, "samples") |
8 | | -model_categories = os.listdir(samples_dir) |
| 7 | +framework2dirname = { |
| 8 | + "torch": "samples", |
| 9 | + "paddle": "paddle_samples", |
| 10 | +} |
9 | 11 |
|
10 | | -graph_net_count = 0 |
11 | | -graph_net_dict = {} |
12 | | -for category in model_categories: |
13 | | - category_dir = os.path.join(samples_dir, category) |
14 | | - if os.path.isdir(category_dir): |
15 | | - graph_net_dict[category] = 0 |
16 | | - for root, dirs, files in os.walk(category_dir): |
17 | | - if "graph_net.json" in files: |
18 | | - graph_net_count += 1 |
19 | | - graph_net_dict[category] += 1 |
| 12 | +for framework in ["torch", "paddle"]: |
| 13 | + samples_dir = os.path.join(root_dir, framework2dirname[framework]) |
| 14 | + model_categories = os.listdir(samples_dir) |
20 | 15 |
|
21 | | -print(f"Number of graph_net.json files: {graph_net_count}") |
22 | | -for name, number in graph_net_dict.items(): |
23 | | - print(f"- {name:24}: {number}") |
24 | | -print() |
| 16 | + graph_net_count = 0 |
| 17 | + graph_net_dict = {} |
| 18 | + for category in model_categories: |
| 19 | + category_dir = os.path.join(samples_dir, category) |
| 20 | + if os.path.isdir(category_dir): |
| 21 | + graph_net_dict[category] = 0 |
| 22 | + for root, dirs, files in os.walk(category_dir): |
| 23 | + if "graph_net.json" in files: |
| 24 | + graph_net_count += 1 |
| 25 | + graph_net_dict[category] += 1 |
| 26 | + |
| 27 | + print(f"Number of {framework} samples: {graph_net_count}") |
| 28 | + for name, number in graph_net_dict.items(): |
| 29 | + print(f"- {name:24}: {number}") |
| 30 | + print() |
0 commit comments