@@ -169,7 +169,7 @@ def pretrain_step(self, data):
         Xs, Xt, ys, yt = self._unpack_data(data)

         # loss
-        with tf.GradientTape() as tape:
+        with tf.GradientTape() as task_tape, tf.GradientTape() as enc_tape:
             # Forward pass
             Xs_enc = self.encoder_src_(Xs, training=True)
             ys_pred = self.task_(Xs_enc, training=True)
@@ -179,14 +179,19 @@ def pretrain_step(self, data):

             # Compute the loss value
             loss = self.task_loss_(ys, ys_pred)
-            loss += sum(self.task_.losses) + sum(self.encoder_src_.losses)
+            task_loss = loss + sum(self.task_.losses)
+            enc_loss = loss + sum(self.encoder_src_.losses)

         # Compute gradients
-        trainable_vars = self.task_.trainable_variables + self.encoder_src_.trainable_variables
-        gradients = tape.gradient(loss, trainable_vars)
+        trainable_vars_task = self.task_.trainable_variables
+        trainable_vars_enc = self.encoder_src_.trainable_variables
+
+        gradients_task = task_tape.gradient(task_loss, trainable_vars_task)
+        gradients_enc = enc_tape.gradient(enc_loss, trainable_vars_enc)

         # Update weights
-        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
+        self.optimizer.apply_gradients(zip(gradients_task, trainable_vars_task))
+        self.optimizer_enc.apply_gradients(zip(gradients_enc, trainable_vars_enc))

         # Update metrics
         self.compiled_metrics.update_state(ys, ys_pred)
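
Note: the rewritten `pretrain_step` records the same forward pass on two tapes so the task head and the source encoder can be stepped by different optimizers (`self.optimizer` and `self.optimizer_enc`). A minimal standalone sketch of this two-tape pattern; the toy networks, learning rates, and variable names below are illustrative assumptions only:

```python
import tensorflow as tf

# Toy networks and optimizers; all choices here are illustrative assumptions.
encoder = tf.keras.Sequential([tf.keras.layers.Dense(16, activation="relu")])
task = tf.keras.Sequential([tf.keras.layers.Dense(1)])
opt_task = tf.keras.optimizers.Adam(1e-3)
opt_enc = tf.keras.optimizers.Adam(1e-4)

X = tf.random.normal((8, 4))
y = tf.random.normal((8, 1))

# Both tapes record the same forward pass, so each one can later
# differentiate its own loss w.r.t. its own variable set.
with tf.GradientTape() as task_tape, tf.GradientTape() as enc_tape:
    y_pred = task(encoder(X, training=True), training=True)
    loss = tf.reduce_mean(tf.square(y - y_pred))
    task_loss = loss + sum(task.losses)    # shared loss + task regularizers
    enc_loss = loss + sum(encoder.losses)  # shared loss + encoder regularizers

grads_task = task_tape.gradient(task_loss, task.trainable_variables)
grads_enc = enc_tape.gradient(enc_loss, encoder.trainable_variables)
opt_task.apply_gradients(zip(grads_task, task.trainable_variables))
opt_enc.apply_gradients(zip(grads_enc, encoder.trainable_variables))
```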
@@ -206,55 +211,48 @@ def train_step(self, data):
         Xs, Xt, ys, yt = self._unpack_data(data)

         # loss
-        with tf.GradientTape() as task_tape, tf.GradientTape() as enc_tape, tf.GradientTape() as disc_tape:
+        with tf.GradientTape() as enc_tape, tf.GradientTape() as disc_tape:
             # Forward pass
-            Xs_enc = self.encoder_src_(Xs, training=True)
-            ys_pred = self.task_(Xs_enc, training=True)
+            if self.pretrain:
+                Xs_enc = self.encoder_src_(Xs, training=False)
+            else:
+                # The source encoder is not needed when pretrain=False
+                Xs_enc = Xs
+
             ys_disc = self.discriminator_(Xs_enc, training=True)

             Xt_enc = self.encoder_(Xt, training=True)
             yt_disc = self.discriminator_(Xt_enc, training=True)

-            # Reshape
-            ys_pred = tf.reshape(ys_pred, tf.shape(ys))
-
             # Compute the loss value
-            task_loss = self.task_loss_(ys, ys_pred)
-
             disc_loss = (-tf.math.log(ys_disc + EPS)
                          - tf.math.log(1 - yt_disc + EPS))

             enc_loss = -tf.math.log(yt_disc + EPS)

-            task_loss = tf.reduce_mean(task_loss)
             disc_loss = tf.reduce_mean(disc_loss)
             enc_loss = tf.reduce_mean(enc_loss)

-            task_loss += sum(self.task_.losses)
             disc_loss += sum(self.discriminator_.losses)
             enc_loss += sum(self.encoder_.losses)

         # Compute gradients
-        trainable_vars_task = self.task_.trainable_variables
         trainable_vars_enc = self.encoder_.trainable_variables
         trainable_vars_disc = self.discriminator_.trainable_variables

-        gradients_task = task_tape.gradient(task_loss, trainable_vars_task)
         gradients_enc = enc_tape.gradient(enc_loss, trainable_vars_enc)
         gradients_disc = disc_tape.gradient(disc_loss, trainable_vars_disc)

         # Update weights
-        self.optimizer.apply_gradients(zip(gradients_task, trainable_vars_task))
-        self.optimizer.apply_gradients(zip(gradients_enc, trainable_vars_enc))
-        self.optimizer.apply_gradients(zip(gradients_disc, trainable_vars_disc))
+        self.optimizer_enc.apply_gradients(zip(gradients_enc, trainable_vars_enc))
+        self.optimizer_disc.apply_gradients(zip(gradients_disc, trainable_vars_disc))

         # Update metrics
-        self.compiled_metrics.update_state(ys, ys_pred)
-        self.compiled_loss(ys, ys_pred)
+        # self.compiled_metrics.update_state(ys, ys_pred)
+        # self.compiled_loss(ys, ys_pred)
         # Return a dict mapping metric names to current value
-        logs = {m.name: m.result() for m in self.metrics}
-        disc_metrics = self._get_disc_metrics(ys_disc, yt_disc)
-        logs.update(disc_metrics)
+        # logs = {m.name: m.result() for m in self.metrics}
+        logs = self._get_disc_metrics(ys_disc, yt_disc)
         return logs


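Note: with the task update moved out of `train_step`, the step reduces to the adversarial game: the discriminator is trained to output 1 on source encodings and 0 on target encodings, while the target encoder is trained to make the discriminator output 1 on target encodings (the inverted-label trick). A small self-contained check of these losses; `EPS` is the module-level stability constant used in the patch, but its value below is an assumption:

```python
import tensorflow as tf

EPS = 1e-6  # assumed value; the module's actual constant is not shown in this hunk

def adversarial_losses(ys_disc, yt_disc):
    """ys_disc / yt_disc: discriminator outputs in (0, 1) on source / target encodings."""
    # Discriminator: push source outputs toward 1, target outputs toward 0.
    disc_loss = tf.reduce_mean(-tf.math.log(ys_disc + EPS)
                               - tf.math.log(1. - yt_disc + EPS))
    # Encoder: fool the discriminator into outputting 1 on target encodings.
    enc_loss = tf.reduce_mean(-tf.math.log(yt_disc + EPS))
    return disc_loss, enc_loss

# A well-separated discriminator gives low disc_loss and high enc_loss.
print(adversarial_losses(tf.constant([0.9, 0.8]), tf.constant([0.1, 0.2])))
```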
@@ -275,12 +273,14 @@ def _get_disc_metrics(self, ys_disc, yt_disc):
     def _initialize_weights(self, shape_X):
         # Init weights encoder
         self(np.zeros((1,) + shape_X))
-        self.encoder_(np.zeros((1,) + shape_X))

         # Set same weights to encoder_src
-        self.encoder_src_ = check_network(self.encoder_,
-                                          copy=True,
-                                          name="encoder_src")
+        if self.pretrain:
+            # The source encoder is not needed when pretrain=False
+            self.encoder_(np.zeros((1,) + shape_X))
+            self.encoder_src_ = check_network(self.encoder_,
+                                              copy=True,
+                                              name="encoder_src")


     def transform(self, X, domain="tgt"):
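
Note: both step methods now assume the model exposes `optimizer_enc` and `optimizer_disc` alongside the compiled `optimizer`; this commit does not show where those attributes are attached. A plausible wiring, written as a hypothetical `compile` override (the signature and defaults below are assumptions, not the library's API):

```python
import tensorflow as tf

class ADDALikeModel(tf.keras.Model):
    # Hypothetical: accept one optimizer per sub-network at compile time.
    def compile(self, optimizer="rmsprop", optimizer_enc=None,
                optimizer_disc=None, **kwargs):
        super().compile(optimizer=optimizer, **kwargs)
        # tf.keras.optimizers.get accepts string identifiers or instances.
        self.optimizer_enc = tf.keras.optimizers.get(optimizer_enc or "rmsprop")
        self.optimizer_disc = tf.keras.optimizers.get(optimizer_disc or "rmsprop")
```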