1414from pyg_lib .testing import remap_keys , withDataset , withSeed
1515
1616argparser = argparse .ArgumentParser ('Hetero neighbor sample benchmark' )
17- argparser .add_argument ('--batch-sizes' , nargs = '+' , type = int , default = [
18- 512 ,
19- 1024 ,
20- 2048 ,
21- 4096 ,
22- 8192 ,
23- ])
17+ argparser .add_argument (
18+ '--batch-sizes' ,
19+ nargs = '+' ,
20+ type = int ,
21+ default = [
22+ 512 ,
23+ 1024 ,
24+ 2048 ,
25+ 4096 ,
26+ 8192 ,
27+ ],
28+ )
2429
2530# TODO (kgajdamo): Support undirected hetero graphs
2631# argparser.add_argument('--directed', action='store_true')
2732argparser .add_argument ('--disjoint' , action = 'store_true' )
28- argparser .add_argument ('--num_neighbors' , type = ast .literal_eval , default = [
29- [- 1 ],
30- [15 , 10 , 5 ],
31- [20 , 15 , 10 ],
32- ])
33+ argparser .add_argument (
34+ '--num_neighbors' ,
35+ type = ast .literal_eval ,
36+ default = [
37+ [- 1 ],
38+ [15 , 10 , 5 ],
39+ [20 , 15 , 10 ],
40+ ],
41+ )
3342# TODO(kgajdamo): Enable sampling with replacement
3443# argparser.add_argument('--replace', action='store_true')
3544argparser .add_argument ('--shuffle' , action = 'store_true' )
3645argparser .add_argument ('--biased' , action = 'store_true' )
3746argparser .add_argument ('--temporal' , action = 'store_true' )
38- argparser .add_argument ('--temporal-strategy' , choices = ['uniform' , 'last' ],
39- default = 'uniform' )
47+ argparser .add_argument (
48+ '--temporal-strategy' ,
49+ choices = ['uniform' , 'last' ],
50+ default = 'uniform' ,
51+ )
4052argparser .add_argument ('--write-csv' , action = 'store_true' )
41- argparser .add_argument ('--libraries' , nargs = "*" , type = str ,
42- default = ['pyg-lib' , 'torch-sparse' ])
53+ argparser .add_argument (
54+ '--libraries' ,
55+ nargs = '*' ,
56+ type = str ,
57+ default = ['pyg-lib' , 'torch-sparse' ],
58+ )
4359args = argparser .parse_args ()
4460
4561
4864def test_hetero_neighbor (dataset , ** kwargs ):
4965 if args .temporal and not args .disjoint :
5066 raise ValueError (
51- "Temporal sampling needs to create disjoint subgraphs" )
67+ 'Temporal sampling needs to create disjoint subgraphs' ,
68+ )
5269
5370 colptr_dict , row_dict = dataset
5471 num_nodes_dict = {k [- 1 ]: v .size (0 ) - 1 for k , v in colptr_dict .items ()}
@@ -57,7 +74,8 @@ def test_hetero_neighbor(dataset, **kwargs):
5774 if args .temporal :
5875 # generate random timestamps
5976 node_time , _ = torch .sort (
60- torch .randint (0 , 100000 , (num_nodes_dict ['paper' ], )))
77+ torch .randint (0 , 100000 , (num_nodes_dict ['paper' ],)),
78+ )
6179 node_time_dict = {'paper' : node_time }
6280 else :
6381 node_time_dict = None
@@ -75,9 +93,10 @@ def test_hetero_neighbor(dataset, **kwargs):
7593 node_perm = torch .arange (0 , num_nodes_dict ['paper' ])
7694
7795 data = defaultdict (list )
78- for num_neighbors , batch_size in product (args .num_neighbors ,
79- args .batch_sizes ):
80-
96+ for num_neighbors , batch_size in product (
97+ args .num_neighbors ,
98+ args .batch_sizes ,
99+ ):
81100 print (f'batch_size={ batch_size } , num_neighbors={ num_neighbors } ):' )
82101 data ['num_neighbors' ].append (num_neighbors )
83102 data ['batch-size' ].append (batch_size )
0 commit comments