
Commit 96385a0 (parent: f6d2992)

add sdae

File tree: 7 files changed, +301 -49 lines


.DS_Store: binary file (0 bytes changed), not shown

README.md (1 addition, 1 deletion)

@@ -19,4 +19,4 @@ Note:
 
 * The pretrained weights is important to initialize the weights of VaDE.
 * Unlike the original code using combined training and test data for training and evaluation, I split the training and test data, and only use training data for training and test data for evaluation. I think it is a more appropriate way to evaluate the method for generalization.
-* As found, with above evaluation scheme and training for 3000 epochs, the clustering accuracy achieved is 93.65\%.
+* As found, with above evaluation scheme and training for 3000 epochs, the clustering accuracy achieved is 94\%.
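Note on the accuracy figure: "clustering accuracy" here is the standard unsupervised metric, which first maps each cluster to a ground-truth label before scoring. The commit does not show how it is computed; the sketch below is a hypothetical version using the Hungarian algorithm via scipy.optimize.linear_sum_assignment (scipy is an assumption, not a dependency shown in this repo).

import numpy as np
from scipy.optimize import linear_sum_assignment

def cluster_acc(y_true, y_pred):
    # Count co-occurrences: cost[i, j] = points in cluster i with true label j.
    y_true, y_pred = np.asarray(y_true), np.asarray(y_pred)
    k = int(max(y_pred.max(), y_true.max())) + 1
    cost = np.zeros((k, k), dtype=np.int64)
    for p, t in zip(y_pred, y_true):
        cost[p, t] += 1
    # Hungarian matching on negated counts maximizes the total matched points.
    row, col = linear_sum_assignment(-cost)
    return cost[row, col].sum() / float(y_pred.size)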

dataset/.DS_Store: binary file (0 bytes changed), not shown

test/test_dae.py (23 additions, 28 deletions)

@@ -1,36 +1,31 @@
+import sys
+sys.path.append("..")
 import torch
+import torch.utils.data
+from torchvision import datasets, transforms
 import numpy as np
+import argparse
 from udlp.autoencoder.denoisingAutoencoder import DenoisingAutoencoder
-from utils import readData
 
 if __name__ == "__main__":
-    # from lib.Tox21_Data import read
-    # x_tr_t, y_tr_t, x_valid_t, y_valid_t, x_te_t, y_te_t = read("./dataset/tox21/", target=0)
+    parser = argparse.ArgumentParser(description='VAE MNIST Example')
+    parser.add_argument('--lr', type=float, default=0.002, metavar='N',
+                        help='learning rate for training (default: 0.001)')
+    parser.add_argument('--batch-size', type=int, default=128, metavar='N',
+                        help='input batch size for training (default: 128)')
+    parser.add_argument('--epochs', type=int, default=10, metavar='N',
+                        help='number of epochs to train (default: 10)')
+    args = parser.parse_args()
+
+    train_loader = torch.utils.data.DataLoader(
+        datasets.MNIST('../dataset/mnist', train=True, download=True,
+                       transform=transforms.ToTensor()),
+        batch_size=args.batch_size, shuffle=True, num_workers=2)
+    test_loader = torch.utils.data.DataLoader(
+        datasets.MNIST('../dataset/mnist', train=False, transform=transforms.ToTensor()),
+        batch_size=args.batch_size, shuffle=False, num_workers=2)
 
-    label_name = ['World', 'Sports', 'Business', 'Sci/Tech']
-    training_num, valid_num, test_num, vocab_size = 110000, 10000, 7600, 10000
-    training_file = 'dataset/agnews_training_110K_10K-TFIDF-words.txt'
-    valid_file = 'dataset/agnews_valid_10K_10K-TFIDF-words.txt'
-    test_file = 'dataset/agnews_test_7600_10K-TFIDF-words.txt'
-
-    randgen = np.random.RandomState(13)
-    trainX, trainY = readData(training_file, training_num, vocab_size, randgen)
-    validX, validY = readData(valid_file, valid_num, vocab_size)
-    testX, testY = readData(test_file, test_num, vocab_size)
-
-    # preprocess, normalize each dimension to be [0, 1] for cross-entropy loss
-    train_max = torch.max(trainX, dim=0, keepdim=True)[0]
-    valid_max = torch.max(validX, dim=0, keepdim=True)[0]
-    test_max = torch.max(testX, dim=0, keepdim=True)[0]
-    print(train_max.size())
-    print(valid_max.size())
-    print(test_max.size())
-    x_max = torch.max(torch.cat((train_max, valid_max, test_max), 0), dim=0, keepdim=True)[0]
-    trainX.div_(x_max)
-    validX.div_(x_max)
-    testX.div_(x_max)
-
-    in_features = trainX.size()[1]
+    in_features = 784
     out_features = 500
     dae = DenoisingAutoencoder(in_features, out_features)
-    dae.fit(trainX, validX, lr=1e-3, num_epochs=10, loss_type="cross-entropy")
+    dae.fit(train_loader, test_loader, lr=args.lr, num_epochs=args.epochs, loss_type="cross-entropy")
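The rewritten script now feeds MNIST DataLoaders straight into dae.fit, which flattens each (batch, 1, 28, 28) image tensor to 784 features; hence in_features = 784. A quick standalone shape check (hypothetical, not part of the commit):

import torch
import torch.utils.data
from torchvision import datasets, transforms

loader = torch.utils.data.DataLoader(
    datasets.MNIST('../dataset/mnist', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=128, shuffle=True)
images, _ = next(iter(loader))
print(images.size())                           # torch.Size([128, 1, 28, 28])
print(images.view(images.size(0), -1).size())  # torch.Size([128, 784]), matches in_features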

test/test_sdae.py (37 additions, 0 deletions)

@@ -0,0 +1,37 @@
+import sys
+sys.path.append("..")
+import torch
+import torch.utils.data
+from torchvision import datasets, transforms
+import numpy as np
+import argparse
+from udlp.autoencoder.stackedDAE import StackedDAE
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='VAE MNIST Example')
+    parser.add_argument('--lr', type=float, default=0.002, metavar='N',
+                        help='learning rate for training (default: 0.001)')
+    parser.add_argument('--batch-size', type=int, default=128, metavar='N',
+                        help='input batch size for training (default: 128)')
+    parser.add_argument('--pretrainepochs', type=int, default=10, metavar='N',
+                        help='number of epochs to train (default: 10)')
+    parser.add_argument('--epochs', type=int, default=10, metavar='N',
+                        help='number of epochs to train (default: 10)')
+    args = parser.parse_args()
+
+    train_loader = torch.utils.data.DataLoader(
+        datasets.MNIST('../dataset/mnist', train=True, download=True,
+                       transform=transforms.ToTensor()),
+        batch_size=args.batch_size, shuffle=True, num_workers=2)
+    test_loader = torch.utils.data.DataLoader(
+        datasets.MNIST('../dataset/mnist', train=False, transform=transforms.ToTensor()),
+        batch_size=args.batch_size, shuffle=False, num_workers=2)
+
+    in_features = 784
+    out_features = 500
+    sdae = StackedDAE(input_dim=784, z_dim=10, binary=True,
+        encodeLayer=[500,500,2000], decodeLayer=[2000,500,500], activation="relu",
+        dropout=0)
+    sdae.pretrain(train_loader, test_loader, lr=args.lr, batch_size=args.batch_size,
+        num_epochs=args.pretrainepochs, corrupt=0.3, loss_type="cross-entropy")
+    sdae.fit(train_loader, test_loader, lr=args.lr, num_epochs=args.epochs, corrupt=0.3, loss_type="cross-entropy")
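test_sdae.py wires up a 784-500-500-2000-10 stacked denoising autoencoder: pretrain does greedy layer-wise training and fit fine-tunes the unrolled network end to end. StackedDAE's internals are not part of this diff; the sketch below shows what the greedy phase presumably looks like, built only from the DenoisingAutoencoder API in this commit. The per-layer loss choice is a guess: cross-entropy only suits the [0, 1] pixel layer, so MSE is assumed for the unbounded hidden codes.

import torch
from torch.utils.data import DataLoader, TensorDataset
from udlp.autoencoder.denoisingAutoencoder import DenoisingAutoencoder

def greedy_pretrain(train_loader, valid_loader, dims=(784, 500, 500, 2000, 10)):
    # Train one DAE per adjacent layer pair, feeding each layer's codes to the next.
    daes = []
    for layer, (d_in, d_out) in enumerate(zip(dims[:-1], dims[1:])):
        dae = DenoisingAutoencoder(d_in, d_out)
        loss = "cross-entropy" if layer == 0 else "mse"
        dae.fit(train_loader, valid_loader, lr=0.002, num_epochs=10,
                corrupt=0.3, loss_type=loss)
        daes.append(dae)
        # Re-encode both splits so the next DAE trains on this layer's codes
        # (dummy zero labels keep the (inputs, _) loader contract).
        tr, va = dae.encodeBatch(train_loader), dae.encodeBatch(valid_loader)
        train_loader = DataLoader(TensorDataset(tr, torch.zeros(len(tr))),
                                  batch_size=128, shuffle=True)
        valid_loader = DataLoader(TensorDataset(va, torch.zeros(len(va))),
                                  batch_size=128, shuffle=False)
    return daes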

udlp/autoencoder/denoisingAutoencoder.py (32 additions, 20 deletions)

@@ -13,25 +13,31 @@
 from udlp.ops import MSELoss, BCELoss
 
 class DenoisingAutoencoder(nn.Module):
-    def __init__(self, in_features, out_features, activation="relu"):
+    def __init__(self, in_features, out_features, activation="relu",
+                 dropout=0.2, tied=False):
         super(self.__class__, self).__init__()
         self.weight = Parameter(torch.Tensor(out_features, in_features))
+        if tied:
+            self.deweight = self.weight.t()
+        else:
+            self.deweight = Parameter(torch.Tensor(in_features, out_features))
         self.bias = Parameter(torch.Tensor(out_features))
         self.vbias = Parameter(torch.Tensor(in_features))
 
         if activation=="relu":
             self.enc_act_func = nn.ReLU()
         elif activation=="sigmoid":
             self.enc_act_func = nn.Sigmoid()
-        self.dropout = nn.Dropout(p=0.2)
+        self.dropout = nn.Dropout(p=dropout)
 
         self.reset_parameters()
 
     def reset_parameters(self):
         stdv = 1. / math.sqrt(self.weight.size(1))
         self.weight.data.uniform_(-stdv, stdv)
         self.bias.data.uniform_(-stdv, stdv)
-        stdv = 1. / math.sqrt(self.vbias.size(0))
+        stdv = 1. / math.sqrt(self.deweight.size(1))
+        self.deweight.data.uniform_(-stdv, stdv)
         self.vbias.data.uniform_(-stdv, stdv)
 
     def forward(self, x):
@@ -44,13 +50,26 @@ def encode(self, x, train=True):
             self.dropout.eval()
         return self.dropout(self.enc_act_func(F.linear(x, self.weight, self.bias)))
 
+    def encodeBatch(self, dataloader):
+        encoded = []
+        for batch_idx, (inputs, _) in enumerate(dataloader):
+            inputs = inputs.view(inputs.size(0), -1).float()
+            if use_cuda:
+                inputs = inputs.cuda()
+            inputs = Variable(inputs)
+            hidden = self.encode(inputs, train=False)
+            encoded.append(hidden.data.cpu())
+
+        encoded = torch.cat(encoded, dim=0)
+        return encoded
+
     def decode(self, x, binary=False):
         if not binary:
-            return F.linear(x, self.weight.t(), self.vbias)
+            return F.linear(x, self.deweight, self.vbias)
         else:
-            return F.sigmoid(F.linear(x, self.weight.t(), self.vbias))
+            return F.sigmoid(F.linear(x, self.deweight, self.vbias))
 
-    def fit(self, data_x, valid_x, lr=0.001, batch_size=128, num_epochs=10, corrupt=0.5,
+    def fit(self, trainloader, validloader, lr=0.001, batch_size=128, num_epochs=10, corrupt=0.3,
         loss_type="mse"):
         """
         data_x: FloatTensor
@@ -60,17 +79,11 @@ def fit(self, data_x, valid_x, lr=0.001, batch_size=128, num_epochs=10, corrupt=
         if use_cuda:
             self.cuda()
         print("=====Denoising Autoencoding layer=======")
-        optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=lr, betas=(0.9, 0.9))
+        optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=lr)
         if loss_type=="mse":
             criterion = MSELoss()
         elif loss_type=="cross-entropy":
             criterion = BCELoss()
-        trainset = Dataset(data_x, data_x)
-        trainloader = torch.utils.data.DataLoader(
-            trainset, batch_size=batch_size, shuffle=True, num_workers=2)
-        validset = Dataset(valid_x, valid_x)
-        validloader = torch.utils.data.DataLoader(
-            validset, batch_size=1000, shuffle=False, num_workers=2)
 
         # validate
         total_loss = 0.0
@@ -87,14 +100,15 @@ def fit(self, data_x, valid_x, lr=0.001, batch_size=128, num_epochs=10, corrupt=
                 outputs = self.decode(hidden)
 
             valid_recon_loss = criterion(outputs, inputs)
-            total_loss += valid_recon_loss.data[0] * inputs.size()[0]
+            total_loss += valid_recon_loss.data[0] * len(inputs)
             total_num += inputs.size()[0]
 
         valid_loss = total_loss / total_num
         print("#Epoch 0: Valid Reconstruct Loss: %.3f" % (valid_loss))
 
         for epoch in range(num_epochs):
             # train 1 epoch
+            train_loss = 0.0
             for batch_idx, (inputs, _) in enumerate(trainloader):
                 inputs = inputs.view(inputs.size(0), -1).float()
                 inputs_corr = masking_noise(inputs, corrupt)
@@ -111,12 +125,12 @@ def fit(self, data_x, valid_x, lr=0.001, batch_size=128, num_epochs=10, corrupt=
                 else:
                     outputs = self.decode(hidden)
                 recon_loss = criterion(outputs, inputs)
+                train_loss += recon_loss.data[0]*len(inputs)
                 recon_loss.backward()
                 optimizer.step()
 
             # validate
-            total_loss = 0.0
-            total_num = 0
+            valid_loss = 0.0
             for batch_idx, (inputs, _) in enumerate(validloader):
                 inputs = inputs.view(inputs.size(0), -1).float()
                 if use_cuda:
@@ -129,10 +143,8 @@ def fit(self, data_x, valid_x, lr=0.001, batch_size=128, num_epochs=10, corrupt=
                     outputs = self.decode(hidden)
 
                 valid_recon_loss = criterion(outputs, inputs)
-                total_loss += valid_recon_loss.data[0] * inputs.size()[0]
-                total_num += inputs.size()[0]
+                valid_loss += valid_recon_loss.data[0] * len(inputs)
 
-            valid_loss = total_loss / total_num
             print("#Epoch %3d: Reconstruct Loss: %.3f, Valid Reconstruct Loss: %.3f" % (
-                epoch+1, recon_loss.data[0], valid_loss))
+                epoch+1, train_loss / len(trainloader.dataset), valid_loss / len(validloader.dataset)))
 
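fit() corrupts each training batch with masking_noise, which is imported from elsewhere in udlp and not shown in this diff. A minimal sketch of the standard masking corruption it presumably implements (the udlp version may differ):

import torch

def masking_noise(x, frac):
    # Zero out a random fraction `frac` of the entries in x,
    # the classic denoising-autoencoder corruption.
    noisy = x.clone()
    noisy[torch.rand(x.size()) < frac] = 0.0
    return noisy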
