-
Notifications
You must be signed in to change notification settings - Fork 13
Expand file tree
/
Copy pathrun.py
More file actions
31 lines (28 loc) · 916 Bytes
/
run.py
File metadata and controls
31 lines (28 loc) · 916 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from model import AdaGAE
import torch
import data_loader as loader
import warnings
import numpy as np
# Hyper-parameter grid searched by this script.
# lam takes the values 2^-10, 2^-8, ..., 2^8 (upper bound 10 is exclusive).
LAMBDAS = np.power(2.0, np.arange(-10, 10, 2))
# Initial neighbour counts passed to AdaGAE as `num_neighbors`.
NEIGHBOR_CHOICES = (5,)


def main():
    """Train AdaGAE on one dataset for every (lam, num_neighbors) grid point.

    Loads the dataset selected below via ``data_loader``, builds one ``AdaGAE``
    model per point of ``LAMBDAS`` x ``NEIGHBOR_CHOICES``, and prints the lists
    of accuracies and NMIs returned by ``gae.run()``.

    Returns:
        tuple[list, list]: the per-run ``acc`` and ``nmi`` values, in grid order.
    """
    warnings.filterwarnings('ignore')

    dataset = loader.UMIST
    data, labels = loader.load_data(dataset)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    X = torch.Tensor(data).to(device)

    # NOTE(review): assumes `data` is 2-D (samples x features) — `shape[1]` is
    # used as the encoder's input width.
    input_dim = data.shape[1]
    # USPS gets a narrower first hidden layer than the other datasets.
    hidden_dim = 128 if dataset is loader.USPS else 256
    layers = [input_dim, hidden_dim, 64]

    accs, nmis = [], []
    for lam in LAMBDAS:
        for neighbors in NEIGHBOR_CHOICES:
            print('-----lambda={}, neighbors={}'.format(lam, neighbors))
            gae = AdaGAE(X, labels, layers=layers, num_neighbors=neighbors, lam=lam,
                         max_iter=50, max_epoch=10, update=True, learning_rate=5e-3,
                         inc_neighbors=5, device=device)
            acc, nmi = gae.run()
            accs.append(acc)
            nmis.append(nmi)
    print(accs)
    print(nmis)
    return accs, nmis


if __name__ == '__main__':
    main()