 from udlp.ops import MSELoss, BCELoss
 
 class DenoisingAutoencoder(nn.Module):
-    def __init__(self, in_features, out_features, activation="relu"):
+    def __init__(self, in_features, out_features, activation="relu",
+                 dropout=0.2, tied=False):
         super(self.__class__, self).__init__()
         self.weight = Parameter(torch.Tensor(out_features, in_features))
+        if tied:
+            self.deweight = self.weight.t()
+        else:
+            self.deweight = Parameter(torch.Tensor(in_features, out_features))
         self.bias = Parameter(torch.Tensor(out_features))
         self.vbias = Parameter(torch.Tensor(in_features))
 
         if activation == "relu":
             self.enc_act_func = nn.ReLU()
         elif activation == "sigmoid":
             self.enc_act_func = nn.Sigmoid()
-        self.dropout = nn.Dropout(p=0.2)
+        self.dropout = nn.Dropout(p=dropout)
 
         self.reset_parameters()
 
     def reset_parameters(self):
         stdv = 1. / math.sqrt(self.weight.size(1))
         self.weight.data.uniform_(-stdv, stdv)
         self.bias.data.uniform_(-stdv, stdv)
-        stdv = 1. / math.sqrt(self.vbias.size(0))
+        stdv = 1. / math.sqrt(self.deweight.size(1))
+        self.deweight.data.uniform_(-stdv, stdv)
         self.vbias.data.uniform_(-stdv, stdv)
 
     def forward(self, x):
@@ -44,13 +50,26 @@ def encode(self, x, train=True):
             self.dropout.eval()
         return self.dropout(self.enc_act_func(F.linear(x, self.weight, self.bias)))
 
+    def encodeBatch(self, dataloader):
+        encoded = []
+        for batch_idx, (inputs, _) in enumerate(dataloader):
+            inputs = inputs.view(inputs.size(0), -1).float()
+            if use_cuda:
+                inputs = inputs.cuda()
+            inputs = Variable(inputs)
+            hidden = self.encode(inputs, train=False)
+            encoded.append(hidden.data.cpu())
+
+        encoded = torch.cat(encoded, dim=0)
+        return encoded
+
     def decode(self, x, binary=False):
         if not binary:
-            return F.linear(x, self.weight.t(), self.vbias)
+            return F.linear(x, self.deweight, self.vbias)
         else:
-            return F.sigmoid(F.linear(x, self.weight.t(), self.vbias))
+            return F.sigmoid(F.linear(x, self.deweight, self.vbias))
 
-    def fit(self, data_x, valid_x, lr=0.001, batch_size=128, num_epochs=10, corrupt=0.5,
+    def fit(self, trainloader, validloader, lr=0.001, batch_size=128, num_epochs=10, corrupt=0.3,
             loss_type="mse"):
         """
         data_x: FloatTensor
@@ -60,17 +79,11 @@ def fit(self, data_x, valid_x, lr=0.001, batch_size=128, num_epochs=10, corrupt=
         if use_cuda:
             self.cuda()
         print("=====Denoising Autoencoding layer=======")
-        optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=lr, betas=(0.9, 0.9))
+        optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.parameters()), lr=lr)
         if loss_type == "mse":
             criterion = MSELoss()
         elif loss_type == "cross-entropy":
             criterion = BCELoss()
-        trainset = Dataset(data_x, data_x)
-        trainloader = torch.utils.data.DataLoader(
-            trainset, batch_size=batch_size, shuffle=True, num_workers=2)
-        validset = Dataset(valid_x, valid_x)
-        validloader = torch.utils.data.DataLoader(
-            validset, batch_size=1000, shuffle=False, num_workers=2)
 
         # validate
         total_loss = 0.0
@@ -87,14 +100,15 @@ def fit(self, data_x, valid_x, lr=0.001, batch_size=128, num_epochs=10, corrupt=
                 outputs = self.decode(hidden)
 
             valid_recon_loss = criterion(outputs, inputs)
-            total_loss += valid_recon_loss.data[0] * inputs.size()[0]
+            total_loss += valid_recon_loss.data[0] * len(inputs)
             total_num += inputs.size()[0]
 
         valid_loss = total_loss / total_num
         print("#Epoch 0: Valid Reconstruct Loss: %.3f" % (valid_loss))
 
         for epoch in range(num_epochs):
             # train 1 epoch
+            train_loss = 0.0
             for batch_idx, (inputs, _) in enumerate(trainloader):
                 inputs = inputs.view(inputs.size(0), -1).float()
                 inputs_corr = masking_noise(inputs, corrupt)
@@ -111,12 +125,12 @@ def fit(self, data_x, valid_x, lr=0.001, batch_size=128, num_epochs=10, corrupt=
                 else:
                     outputs = self.decode(hidden)
                 recon_loss = criterion(outputs, inputs)
+                train_loss += recon_loss.data[0] * len(inputs)
                 recon_loss.backward()
                 optimizer.step()
 
             # validate
-            total_loss = 0.0
-            total_num = 0
+            valid_loss = 0.0
             for batch_idx, (inputs, _) in enumerate(validloader):
                 inputs = inputs.view(inputs.size(0), -1).float()
                 if use_cuda:
@@ -129,10 +143,8 @@ def fit(self, data_x, valid_x, lr=0.001, batch_size=128, num_epochs=10, corrupt=
                     outputs = self.decode(hidden)
 
                 valid_recon_loss = criterion(outputs, inputs)
-                total_loss += valid_recon_loss.data[0] * inputs.size()[0]
-                total_num += inputs.size()[0]
+                valid_loss += valid_recon_loss.data[0] * len(inputs)
 
-            valid_loss = total_loss / total_num
             print("#Epoch %3d: Reconstruct Loss: %.3f, Valid Reconstruct Loss: %.3f" % (
-                epoch + 1, recon_loss.data[0], valid_loss))
+                epoch + 1, train_loss / len(trainloader.dataset), valid_loss / len(validloader.dataset)))
 
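With this change fit() no longer builds its own DataLoaders, so the caller is expected to supply trainloader and validloader. A minimal caller-side sketch under that assumption; the tensors train_x / valid_x, the 784/500 layer sizes, and the loader settings below are illustrative examples, not part of this commit:

# Hypothetical usage sketch; only the class above is changed by the commit.
import torch
import torch.utils.data

# Assumed: train_x / valid_x are FloatTensors of flattened inputs.
trainloader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(train_x, train_x),
    batch_size=128, shuffle=True, num_workers=2)
validloader = torch.utils.data.DataLoader(
    torch.utils.data.TensorDataset(valid_x, valid_x),
    batch_size=1000, shuffle=False, num_workers=2)

# New constructor arguments introduced here: dropout rate and optional tied decoder weights.
dae = DenoisingAutoencoder(784, 500, activation="relu", dropout=0.2, tied=False)
dae.fit(trainloader, validloader, lr=0.001, num_epochs=10,
        corrupt=0.3, loss_type="mse")
features = dae.encodeBatch(validloader)  # hidden codes, e.g. as input to the next stacked layer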