Skip to content

Commit 0f3d97b

Browse files
committed
added minimal implementation of Deep GPs
1 parent 3f44472 commit 0f3d97b

File tree

2 files changed

+104
-10
lines changed

2 files changed

+104
-10
lines changed

naslib/benchmarks/nas_predictors/submit-all.sh

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,29 @@
11
#!/bin/bash
22

33
#nb101
4-
predictors101=(bananas feedforward gbdt gcn xgb ngb rf dngo \
5-
bohamiann bayes_lin_reg seminas nao gp sparse_gp var_sparse_gp)
4+
#predictors101=(bananas feedforward gbdt gcn xgb ngb rf dngo \
5+
#bohamiann bayes_lin_reg seminas nao gp sparse_gp var_sparse_gp)
66

77
#nb201
8-
predictors201=(bananas feedforward gbdt gcn bonas xgb ngb rf dngo \
9-
bohamiann bayes_lin_reg seminas nao gp sparse_gp var_sparse_gp)
8+
#predictors201=(bananas feedforward gbdt gcn bonas xgb ngb rf dngo \
9+
#bohamiann bayes_lin_reg seminas nao gp sparse_gp var_sparse_gp)
10+
predictors201=(ngb_hp omni)
1011

1112
#nb301
1213
#predictors301=(bananas feedforward gbdt bonas xgb ngb rf dngo \
1314
# bohamiann bayes_lin_reg gp sparse_gp var_sparse_gp nao)
1415

15-
for predictor in ${predictors101[@]}
16-
do
17-
sbatch -J 101-${predictor} slurm_job-nb101.sh $predictor
18-
done
16+
#for predictor in ${predictors101[@]}
17+
#do
18+
#sbatch -J 101-${predictor} slurm_job-nb101.sh $predictor
19+
#done
1920

2021
for predictor in ${predictors201[@]}
2122
do
2223
#sbatch -J 201-${predictor} slurm_job-nb201-c10.sh $predictor
23-
sbatch -J c100-201-${predictor} slurm_job-nb201-c100.sh $predictor
24-
sbatch -J imnet-201-${predictor} slurm_job-nb201-imagenet.sh $predictor
24+
#sbatch -J c100-201-${predictor} slurm_job-nb201-c100.sh $predictor
25+
#sbatch -J imnet-201-${predictor} slurm_job-nb201-imagenet.sh $predictor
26+
sbatch -J imnet-201-${predictor} slurm_job-imgnet.sh $predictor
2527
done
2628

2729
#for predictor in ${predictors301[@]}

naslib/predictors/gp/deep_gp.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
import torch
2+
import pyro.contrib.gp as gp
3+
import pyro.distributions as dist
4+
import numpy as np
5+
6+
from naslib.predictors.gp_base import BaseGPModel
7+
from naslib.utils.utils import AverageMeterGroup, TensorDatasetWithTrans
8+
9+
device = torch.device('cpu') #NOTE: faster on CPU
10+
11+
# TODO
12+
13+
class DeepVarSparseGP(pyro.nn.PyroModule):
14+
def __init__(self, X, y, Xu, mean_fn):
15+
super(DeepVarSparseGP, self).__init__()
16+
self.layer1 = gp.models.VariationalSparseGP(
17+
X,
18+
None,
19+
gp.kernels.RBF(X.shape[1], variance=torch.tensor(5.).double(),
20+
lengthscale=torch.tensor(10.).double()),
21+
Xu=Xu,
22+
likelihood=None,
23+
mean_function=mean_fn,
24+
latent_shape=torch.Size([10]))
25+
26+
h = mean_fn(X).t()
27+
hu = mean_fn(Xu).t()
28+
self.layer2 = gp.models.VariationalSparseGP(
29+
h,
30+
y,
31+
gp.kernels.RBF(10, variance=torch.tensor(5.).double(),
32+
lengthscale=torch.tensor(10.).double()),
33+
Xu=hu,
34+
likelihood=gp.likelihoods.Gaussian(),
35+
latent_shape=torch.Size([1]))
36+
37+
def model(self, X, y):
38+
self.layer1.set_data(X, None)
39+
h_loc, h_var = self.layer1.model()
40+
# approximate with MC sample
41+
h = dist.Normal(h_loc, h_var.sqrt())()
42+
self.layer2.set_data(h.t(), y)
43+
self.layer2.model()
44+
45+
def guide(self, X, y):
46+
self.layer1.guide()
47+
self.layer2.guide()
48+
49+
# make predictions
50+
def forward(self, X_new):
51+
# because prediction is stochastic (due to Monte Carlo sample of hidden layer),
52+
# we make 100 prediction and take the most common one
53+
pred = []
54+
for _ in range(100):
55+
h_loc, h_var = self.layer1(X_new)
56+
h = dist.Normal(h_loc, h_var.sqrt())()
57+
f_loc, f_var = self.layer2(h.t())
58+
pred.append(f_loc.argmax(dim=0))
59+
return torch.stack(pred).mode(dim=0)[0]
60+
61+
62+
class DeepVarSparseGPPredictor(BaseGPModel):
63+
64+
def get_dataset(self, encodings, labels=None):
65+
if labels is None:
66+
return torch.tensor(encodings).double()
67+
else:
68+
return (torch.tensor(encodings).double(),
69+
torch.tensor((labels-self.mean)/self.std).double())
70+
X_tensor = torch.FloatTensor(_xtrain).to(device)
71+
y_tensor = torch.FloatTensor(_ytrain).to(device)
72+
73+
train_data = TensorDataset(X_tensor, y_tensor)
74+
75+
def get_model(self, train_data, **kwargs):
76+
deepgp = DeepVarSparseGP(
77+
78+
def train(self, train_data, optimize_gp_hyper=False):
79+
X_train, y_train = train_data
80+
# initialize the kernel and model
81+
pyro.clear_param_store()
82+
kernel = self.kernel(input_dim=X_train.shape[1])
83+
Xu = torch.arange(10.) / 2.0
84+
Xu.unsqueeze_(-1)
85+
Xu = Xu.expand(10, X_train.shape[1]).double()
86+
likelihood = gp.likelihoods.Gaussian()
87+
self.gpr = gp.models.VariationalSparseGP(X_train, y_train, kernel,
88+
Xu=Xu, likelihood=likelihood,
89+
whiten=True)
90+
91+
return self.gpr
92+

0 commit comments

Comments
 (0)