-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathedge2vec.py
More file actions
39 lines (29 loc) · 1.25 KB
/
edge2vec.py
File metadata and controls
39 lines (29 loc) · 1.25 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import os
import sys
from configargparse import ArgumentParser
from src.deep_pre import DeepPre
from src.csv2tf_neg import convert
from src.deep_negative import pretrain
def parse_arguments(argv=None):
    """Parse the edge2vec command-line options.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) the parser falls back to ``sys.argv[1:]``,
            exactly as the original zero-argument call did.

    Returns:
        The namespace produced by ``ArgumentParser.parse_args`` with the
        attributes ``input`` (str), ``model`` (str), ``num`` (int) and
        ``sample`` (int). All four options are required; the parser exits
        with a usage error when one is missing or fails type conversion.
    """
    parser = ArgumentParser(description='Arguments For edge2vec')
    group = parser.add_argument_group('Base Configs')
    group.add_argument('-i', '--input', help='path to the input graph file',
                       type=str, required=True)
    group.add_argument('-m', '--model', help='the output directory of model files',
                       type=str, required=True)
    group.add_argument('-n', '--num', help='the maximum num of the node',
                       type=int, required=True)
    group.add_argument('-s', '--sample', help='the num of negative samples',
                       type=int, required=True)
    return parser.parse_args(argv)
def main():
    """Run the edge2vec pipeline from command-line options.

    Steps, in order:
      1. Parse the options with ``parse_arguments``.
      2. Ensure the model output directory ``args.model`` exists.
      3. Build a ``DeepPre`` and call ``read_data``, ``calculate`` and
         ``write_csv`` on it.
      4. Call ``convert(args.model, 2 * args.num)``.
      5. Call ``pretrain(2 * args.num, args.model)``.
    """
    args = parse_arguments()
    # Drop our own options from sys.argv once they have been parsed.
    # NOTE(review): presumably a downstream module (src.csv2tf_neg /
    # src.deep_negative) re-parses sys.argv and would reject these
    # options -- confirm before removing this line.
    sys.argv = sys.argv[:1]

    # exist_ok=True avoids the check-then-create race of exists()+makedirs(),
    # and raises FileExistsError up front if the path exists but is not a
    # directory (the old code silently skipped creation in that case).
    os.makedirs(args.model, exist_ok=True)

    # NOTE(review): args.num is passed for two consecutive DeepPre
    # parameters; their meaning is defined in src.deep_pre -- confirm there.
    pre = DeepPre(args.input, args.model, args.num, args.num, args.sample)
    pre.read_data()
    pre.calculate()
    pre.write_csv()

    # NOTE(review): both downstream steps receive twice the --num value; the
    # reason for the factor of 2 is not visible here -- confirm in src.
    convert(args.model, args.num * 2)
    pretrain(args.num * 2, args.model)
# Script entry point: run the pipeline only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    main()