Skip to content

Commit 92dd5e4

Browse files
committed
Fix input dtype for paddle validate and enable counting paddle samples.
1 parent b3fb4bf commit 92dd5e4

File tree

3 files changed

+28
-19
lines changed

3 files changed

+28
-19
lines changed

graph_net/paddle/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,4 +170,4 @@ def replay_tensor(info):
170170
if "data" in info and info["data"] is not None:
171171
return info["data"].to(device)
172172

173-
return paddle.randn(shape).to(dtype).to(device) * std * 1e-3 + 1e-2
173+
return (paddle.randn(shape).cast(dtype).to(device) * std * 1e-3 + 1e-2).cast(dtype)

graph_net/paddle/validate.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,13 @@ def main(args):
7171
print(np.argmin(y), np.argmax(y))
7272
if isinstance(y, paddle.Tensor):
7373
print(y.shape)
74-
elif isinstance(y, list) or isinstance(y, tuple):
74+
elif (isinstance(y, list) or isinstance(y, tuple)) and all(
75+
isinstance(obj, paddle.Tensor) for obj in y
76+
):
77+
# list of paddle.Tensor
7578
print(y[0].shape)
7679
else:
77-
raise ValueError("Illegal Return Value.")
80+
raise ValueError("Illegal return value.")
7881

7982
if not args.no_check_redundancy:
8083
print("Check redundancy ...")

tools/count_sample.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,27 @@
44

55
filename = os.path.abspath(__file__)
66
root_dir = os.path.dirname(os.path.dirname(filename))
7-
samples_dir = os.path.join(root_dir, "samples")
8-
model_categories = os.listdir(samples_dir)
7+
framework2dirname = {
8+
"torch": "samples",
9+
"paddle": "paddle_samples",
10+
}
911

10-
graph_net_count = 0
11-
graph_net_dict = {}
12-
for category in model_categories:
13-
category_dir = os.path.join(samples_dir, category)
14-
if os.path.isdir(category_dir):
15-
graph_net_dict[category] = 0
16-
for root, dirs, files in os.walk(category_dir):
17-
if "graph_net.json" in files:
18-
graph_net_count += 1
19-
graph_net_dict[category] += 1
12+
for framework in ["torch", "paddle"]:
13+
samples_dir = os.path.join(root_dir, framework2dirname[framework])
14+
model_categories = os.listdir(samples_dir)
2015

21-
print(f"Number of graph_net.json files: {graph_net_count}")
22-
for name, number in graph_net_dict.items():
23-
print(f"- {name:24}: {number}")
24-
print()
16+
graph_net_count = 0
17+
graph_net_dict = {}
18+
for category in model_categories:
19+
category_dir = os.path.join(samples_dir, category)
20+
if os.path.isdir(category_dir):
21+
graph_net_dict[category] = 0
22+
for root, dirs, files in os.walk(category_dir):
23+
if "graph_net.json" in files:
24+
graph_net_count += 1
25+
graph_net_dict[category] += 1
26+
27+
print(f"Number of {framework} samples: {graph_net_count}")
28+
for name, number in graph_net_dict.items():
29+
print(f"- {name:24}: {number}")
30+
print()

0 commit comments

Comments
 (0)