forked from yuziGuo/PolyFilterPlayground
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path test_dataset.py
More file actions
45 lines (38 loc) · 1.41 KB
/
test_dataset.py
File metadata and controls
45 lines (38 loc) · 1.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from data.platonov_dataloader import platonov_dataloader
from data.geom_dataloader import geom_dataloader
from data.linkx_dataloader import linkx_dataloader
def test_platonov(ds_names=('questions', 'roman-empire', 'minesweeper',
                            'tolokers', 'amazon_ratings'),
                  device='cuda:1'):
    """Smoke-test ``platonov_dataloader`` and print a summary per dataset.

    For every name in *ds_names* a loader is built and ``load_data()`` and
    ``load_a_mask()`` are called.  The feature/label shapes, the number of
    classes and the per-class node counts are then printed.  Nothing is
    asserted, so the test fails only if one of these calls raises.

    Both parameters have defaults, so pytest still collects this as an
    ordinary zero-argument test.

    Args:
        ds_names: Dataset names passed to ``platonov_dataloader``.  This is a
            tuple so the default cannot be mutated between calls.
        device: Device string forwarded to the loader.  It defaults to the
            previously hard-coded ``'cuda:1'``; override it on machines
            without a second GPU.
    """
    for ds in ds_names:
        # NOTE(review): the third positional argument (True) is passed through
        # unchanged; its meaning is defined by platonov_dataloader -- confirm.
        loader = platonov_dataloader(ds, device, True)
        loader.load_data()
        loader.load_a_mask()
        print('Success!')
        print(ds)
        print(f"features.shape: {loader.features.shape}")
        print(f"labels.shape: {loader.labels.shape}")
        print(f"n_classes: {loader.n_classes}")
        print("number of nodes in each class:")
        # Assumes labels hold integer class ids in range(n_classes) -- TODO
        # confirm against the dataloader.
        for c in range(loader.n_classes):
            print((loader.labels == c).sum())
        print()  # blank separator line between datasets
def test_linkx(ds_names=('genius',), device='cuda:1'):
    """Smoke-test ``linkx_dataloader`` and print a summary per dataset.

    For every name in *ds_names* a loader is built and ``load_data()`` and
    ``load_a_mask()`` are called.  The feature/label shapes, the number of
    classes and the per-class node counts are then printed.  Nothing is
    asserted, so the test fails only if one of these calls raises.

    Both parameters have defaults, so pytest still collects this as an
    ordinary zero-argument test.

    Args:
        ds_names: Dataset names passed to ``linkx_dataloader``.  This is a
            tuple so the default cannot be mutated between calls.
        device: Device string forwarded to the loader.  It defaults to the
            previously hard-coded ``'cuda:1'``; override it on machines
            without a second GPU.
    """
    for ds in ds_names:
        # NOTE(review): the third positional argument (True) is passed through
        # unchanged; its meaning is defined by linkx_dataloader -- confirm.
        loader = linkx_dataloader(ds, device, True)
        loader.load_data()
        loader.load_a_mask()
        print('Success!')
        print(ds)
        print(f"features.shape: {loader.features.shape}")
        print(f"labels.shape: {loader.labels.shape}")
        print(f"n_classes: {loader.n_classes}")
        print("number of nodes in each class:")
        # Assumes labels hold integer class ids in range(n_classes) -- TODO
        # confirm against the dataloader.
        for c in range(loader.n_classes):
            print((loader.labels == c).sum())
        print()  # blank separator line between datasets
def test_geom(ds_name='chameleon', device='cuda:1'):
    """Smoke-test ``geom_dataloader`` by loading a single dataset.

    Only ``load_data()`` is called.  Nothing is printed or asserted, so the
    test fails only if construction or loading raises.

    Both parameters have defaults, so pytest still collects this as an
    ordinary zero-argument test.

    Args:
        ds_name: Dataset name passed to ``geom_dataloader``.
        device: Device string forwarded to the loader.  It defaults to the
            previously hard-coded ``'cuda:1'``; override it on machines
            without a second GPU.
    """
    # NOTE(review): the third positional argument is False here, whereas the
    # other loaders in this file receive True; its meaning is defined by
    # geom_dataloader -- confirm this is intentional.
    loader = geom_dataloader(ds_name, device, False)
    loader.load_data()
if __name__ == "__main__":
    # Running this file directly exercises only the Platonov datasets;
    # test_linkx and test_geom have to be invoked separately (e.g. by pytest).
    test_platonov()