@@ -106,25 +106,25 @@ def copyParam(self, daeLayers):
106106 every = 3
107107 # input layer
108108 # copy encoder weight
109- self .encoder [0 ].weight .copy_ (daeLayers [l ].weight )
110- self .encoder [0 ].bias .copy_ (daeLayers [l ].bias )
111- self ._dec .weight .copy_ (daeLayers [l ].deweight )
112- self ._dec .bias .copy_ (daeLayers [l ].vbias )
109+ self .encoder [0 ].weight .data . copy_ (daeLayers [0 ].weight . data )
110+ self .encoder [0 ].bias .data . copy_ (daeLayers [0 ].bias . data )
111+ self ._dec .weight .data . copy_ (daeLayers [0 ].deweight . data )
112+ self ._dec .bias .data . copy_ (daeLayers [0 ].vbias . data )
113113
114114 for l in range (1 , len (self .layers )- 2 ):
115115 # copy encoder weight
116- self .encoder [l * every ].weight .copy_ (daeLayers [l ].weight )
117- self .encoder [l * every ].bias .copy_ (daeLayers [l ].bias )
116+ self .encoder [l * every ].weight .data . copy_ (daeLayers [l ].weight . data )
117+ self .encoder [l * every ].bias .data . copy_ (daeLayers [l ].bias . data )
118118
119119 # copy decoder weight
120- self .decoder [- (l - 1 )* every - 1 ].weight .copy_ (daeLayers [l ].deweight )
121- self .decoder [- (l - 1 )* every - 1 ].bias .copy_ (daeLayers [l ].vbias )
120+ self .decoder [- (l - 1 )* every - 2 ].weight .data . copy_ (daeLayers [l ].deweight . data )
121+ self .decoder [- (l - 1 )* every - 2 ].bias .data . copy_ (daeLayers [l ].vbias . data )
122122
123123 # z layer
124- self ._enc_mu .weight .copy_ (daeLayers [- 1 ].weight )
125- self ._enc_mu .bias .copy_ (daeLayers [- 1 ].bias )
126- self .decoder [0 ].weight .copy_ (daeLayers [- 1 ].deweight )
127- self .decoder [0 ].bias .copy_ (daeLayers [- 1 ].vbias )
124+ self ._enc_mu .weight .data . copy_ (daeLayers [- 1 ].weight . data )
125+ self ._enc_mu .bias .data . copy_ (daeLayers [- 1 ].bias . data )
126+ self .decoder [0 ].weight .data . copy_ (daeLayers [- 1 ].deweight . data )
127+ self .decoder [0 ].bias .data . copy_ (daeLayers [- 1 ].vbias . data )
128128
129129 def fit (self , trainloader , validloader , lr = 0.001 , num_epochs = 10 , corrupt = 0.3 ,
130130 loss_type = "mse" ):
@@ -135,7 +135,7 @@ def fit(self, trainloader, validloader, lr=0.001, num_epochs=10, corrupt=0.3,
135135 use_cuda = torch .cuda .is_available ()
136136 if use_cuda :
137137 self .cuda ()
138- print ("=====Denoising Autoencoding layer=======" )
138+ print ("=====Stacked Denoising Autoencoding layer=======" )
139139 optimizer = optim .Adam (filter (lambda p : p .requires_grad , self .parameters ()), lr = lr )
140140 if loss_type == "mse" :
141141 criterion = MSELoss ()
@@ -150,11 +150,7 @@ def fit(self, trainloader, validloader, lr=0.001, num_epochs=10, corrupt=0.3,
150150 if use_cuda :
151151 inputs = inputs .cuda ()
152152 inputs = Variable (inputs )
153- hidden = self .encode (inputs )
154- if loss_type == "cross-entropy" :
155- outputs = self .decode (hidden , binary = True )
156- else :
157- outputs = self .decode (hidden )
153+ z , outputs = self .forward (inputs )
158154
159155 valid_recon_loss = criterion (outputs , inputs )
160156 total_loss += valid_recon_loss .data [0 ] * len (inputs )
@@ -176,11 +172,7 @@ def fit(self, trainloader, validloader, lr=0.001, num_epochs=10, corrupt=0.3,
176172 inputs = Variable (inputs )
177173 inputs_corr = Variable (inputs_corr )
178174
179- hidden = self .encode (inputs_corr )
180- if loss_type == "cross-entropy" :
181- outputs = self .decode (hidden , binary = True )
182- else :
183- outputs = self .decode (hidden )
175+ z , outputs = self .forward (inputs_corr )
184176 recon_loss = criterion (outputs , inputs )
185177 train_loss += recon_loss .data [0 ]* len (inputs )
186178 recon_loss .backward ()
@@ -193,11 +185,7 @@ def fit(self, trainloader, validloader, lr=0.001, num_epochs=10, corrupt=0.3,
193185 if use_cuda :
194186 inputs = inputs .cuda ()
195187 inputs = Variable (inputs )
196- hidden = self .encode (inputs , train = False )
197- if loss_type == "cross-entropy" :
198- outputs = self .decode (hidden , binary = True )
199- else :
200- outputs = self .decode (hidden )
188+ z , outputs = self .forward (inputs )
201189
202190 valid_recon_loss = criterion (outputs , inputs )
203191 valid_loss += valid_recon_loss .data [0 ] * len (inputs )
0 commit comments