@@ -412,41 +412,65 @@ def _initialize_networks(self):
412412 else :
413413 self .task_ = check_network (self .task ,
414414 copy = self .copy ,
415+ force_copy = True ,
415416 name = "task" )
417+
418+
419+ def _initialize_weights (self , shape_X ):
420+ if hasattr (self , "task_" ):
421+ self .task_ .build ((None ,) + shape_X )
422+ self .build ((None ,) + shape_X )
416423 self ._add_regularization ()
417424
418425
419- def _get_regularizer (self , old_weight , weight , lambda_ = 1. ):
426+ def _get_regularizer (self , old_weight , weight , lambda_ ):
420427 if self .regularizer == "l2" :
421- def regularizer ():
422- return lambda_ * tf .reduce_mean (tf .square (old_weight - weight ))
428+ return lambda_ * tf .reduce_mean (tf .square (old_weight - weight ))
423429 if self .regularizer == "l1" :
424- def regularizer ():
425- return lambda_ * tf .reduce_mean (tf .abs (old_weight - weight ))
430+ return lambda_ * tf .reduce_mean (tf .abs (old_weight - weight ))
426431 return regularizer
427432
428433
def train_step(self, data):
    """Run one optimization step on a batch.

    The task network is fit on the target data while every trainable
    variable is additionally pulled toward its frozen source value by
    the transfer regularization term.

    Parameters
    ----------
    data : tuple
        Packed batch; ``self._unpack_data`` yields source inputs/labels
        (``Xs``, ``ys`` — unused here) and target inputs/labels
        (``Xt``, ``yt``).

    Returns
    -------
    dict
        Metric logs as produced by ``self._update_logs``.
    """
    # Unpack the data. Only the target stream drives the loss.
    Xs, Xt, ys, yt = self._unpack_data(data)

    # Run forward pass under the tape so gradients can be taken.
    with tf.GradientTape() as tape:
        y_pred = self.task_(Xt, training=True)
        # Prefer `_compile_loss` when the model exposes it (newer Keras),
        # otherwise fall back to the public `compiled_loss`.
        if hasattr(self, "_compile_loss") and self._compile_loss is not None:
            loss = self._compile_loss(yt, y_pred)
        else:
            loss = self.compiled_loss(yt, y_pred)

        loss = tf.reduce_mean(loss)
        loss += sum(self.losses)
        # Transfer penalty: one term per trainable variable, pairing the
        # frozen snapshot with its lambda. `zip` replaces the former
        # `range(len(...))` index loop; `old_weights_` has exactly one
        # entry per trainable variable and `lambdas_` is at least as
        # long, so the pairing is unchanged.
        reg_loss = 0.
        for old_w, w, lam in zip(self.old_weights_,
                                 self.task_.trainable_variables,
                                 self.lambdas_):
            reg_loss += self._get_regularizer(old_w, w, lam)
        loss += reg_loss

    # Run backwards pass and update the task parameters.
    gradients = tape.gradient(loss, self.task_.trainable_variables)
    self.optimizer.apply_gradients(zip(gradients, self.task_.trainable_variables))
    return self._update_logs(yt, y_pred)
459+
460+
def _add_regularization(self):
    """Snapshot the source weights and resolve the per-weight trade-off
    parameters used by the transfer regularization.

    Sets ``self.lambdas_`` (one value per task weight) and
    ``self.old_weights_`` (frozen copies of the task trainable
    variables).
    """
    # Resolve one lambda per task weight.
    if hasattr(self.lambdas, "__iter__"):
        # Pad with the last provided value up to the number of weights,
        # then reverse so the given order lines up with the forward
        # variable order. (Reversing a homogeneous scalar-derived list
        # would be a no-op, so only the iterable branch needs it.)
        pad = len(self.task_.weights) - len(self.lambdas)
        padded = self.lambdas + [self.lambdas[-1]] * pad
        self.lambdas_ = padded[::-1]
    else:
        # A single scalar applies uniformly to every weight.
        self.lambdas_ = [self.lambdas] * len(self.task_.weights)

    # NOTE(review): ``lambdas_`` is sized from ``task_.weights`` but is
    # consumed per ``trainable_variables`` — alignment assumes all
    # weights are trainable; confirm for networks with frozen layers.

    # Freeze a copy of every trainable variable as the transfer anchor.
    self.old_weights_ = []
    for var in self.task_.trainable_variables:
        snapshot = tf.identity(var)
        snapshot.trainable = False
        self.old_weights_.append(snapshot)
450474
451475
452476 def call (self , inputs ):
0 commit comments