 
 from adapt.base import BaseAdaptDeep, make_insert_doc
 from adapt.utils import (check_arrays, check_network, get_default_task,
-                         set_random_seed, check_estimator, check_sample_weight)
+                         set_random_seed, check_estimator, check_sample_weight, check_if_compiled)
 
 EPS = np.finfo(np.float32).eps
 
@@ -141,8 +141,21 @@ def _initialize_networks(self):
                                   name="weighter")
         self.sigma_ = tf.Variable(self.sigma_init,
                                   trainable=self.update_sigma)
-
-
+
+        if not hasattr(self, "estimator_"):
+            self.estimator_ = check_estimator(self.estimator,
+                                              copy=self.copy,
+                                              force_copy=True)
+
+
+    def _initialize_weights(self, shape_X):
+        if hasattr(self, "weighter_"):
+            self.weighter_.build((None,) + shape_X)
+        self.build((None,) + shape_X)
+        if isinstance(self.estimator_, Model):
+            self.estimator_.build((None,) + shape_X)
+
+
     def pretrain_step(self, data):
         # Unpack the data.
         Xs, Xt, ys, yt = self._unpack_data(data)
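
Note on the `_initialize_weights` hunk above: it relies on Keras' `build()` to create a model's variables from an input shape before any data has flowed through it, with `(None,) + shape_X` leaving the batch dimension free. A minimal standalone sketch of that pattern (the architecture here is an illustrative assumption, not taken from this diff):

import tensorflow as tf

# Stand-in weighter network; the real one is built elsewhere in adapt.
weighter = tf.keras.Sequential([
    tf.keras.layers.Dense(16, activation="relu"),
    tf.keras.layers.Dense(1, activation="softplus"),
])
shape_X = (5,)                     # feature shape, without the batch axis
weighter.build((None,) + shape_X)  # None leaves the batch size unspecified
print([w.shape for w in weighter.weights])  # variables now exist
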
@@ -163,7 +176,7 @@ def pretrain_step(self, data):
         gradients = tape.gradient(loss, trainable_vars)
 
         # Update weights
-        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
+        self.pretrain_optimizer.apply_gradients(zip(gradients, trainable_vars))
 
         logs = {"loss": loss}
         return logs
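
Note on the `pretrain_optimizer` change above: giving pretraining its own optimizer instance keeps the main optimizer's internal state (e.g. Adam's moment estimates and step counter) untouched until actual training begins. A small sketch of the idea, assuming `pretrain_optimizer` is created elsewhere (e.g. in `compile` or `fit`):

import tensorflow as tf

pretrain_optimizer = tf.keras.optimizers.Adam(1e-3)  # separate slot state
optimizer = tf.keras.optimizers.Adam(1e-3)           # used after pretraining

w = tf.Variable(1.0)
with tf.GradientTape() as tape:
    loss = (w - 2.0) ** 2
grads = tape.gradient(loss, [w])
pretrain_optimizer.apply_gradients(zip(grads, [w]))  # main optimizer untouched
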
@@ -200,7 +213,7 @@ def train_step(self, data):
 
         # Update weights
         self.optimizer.apply_gradients(zip(gradients, trainable_vars))
-        self.optimizer.apply_gradients(zip(gradients_sigma, [self.sigma_]))
+        self.optimizer_sigma.apply_gradients(zip(gradients_sigma, [self.sigma_]))
 
         # Return a dict mapping metric names to current value
         logs = {"loss": loss, "sigma": self.sigma_}
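
Note on the `optimizer_sigma` change above: each `apply_gradients` call advances a Keras optimizer's iteration counter, so applying the same optimizer twice per training step (once for the network, once for `sigma_`) would double-count steps and skew any learning-rate schedule, besides mixing slot state across unrelated variables. A plausible illustration of the double-counting (hypothetical variables, not from this diff):

import tensorflow as tf

opt = tf.keras.optimizers.Adam(1e-3)
w = tf.Variable(1.0)
sigma = tf.Variable(0.5)
opt.apply_gradients([(tf.constant(0.1), w)])      # one update...
opt.apply_gradients([(tf.constant(0.2), sigma)])  # ...counted as two steps
print(int(opt.iterations))  # 2 -- one logical step counted twice
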
@@ -214,6 +227,26 @@ def fit(self, X, y=None, Xt=None, yt=None, domains=None,
         return self
 
 
+    def compile(self,
+                optimizer=None,
+                loss=None,
+                metrics=None,
+                loss_weights=None,
+                weighted_metrics=None,
+                run_eagerly=None,
+                steps_per_execution=None,
+                **kwargs):
+        super().compile(optimizer=optimizer,
+                        loss=loss,
+                        metrics=metrics,
+                        loss_weights=loss_weights,
+                        weighted_metrics=weighted_metrics,
+                        run_eagerly=run_eagerly,
+                        steps_per_execution=steps_per_execution,
+                        **kwargs)
+        self.optimizer_sigma = self.optimizer.__class__.from_config(self.optimizer.get_config())
+
+
     def fit_weights(self, Xs, Xt, **fit_params):
         """
         Fit importance weighting.
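
Note on the `compile` override above: the `get_config()`/`from_config()` round trip is the standard Keras way to clone an optimizer, producing a second instance of the same class with the same hyperparameters but fresh internal state, which is what the separate `sigma_` update needs. A short sketch:

import tensorflow as tf

opt = tf.keras.optimizers.Adam(learning_rate=1e-3)
opt_sigma = opt.__class__.from_config(opt.get_config())
assert type(opt_sigma) is type(opt)  # same class and hyperparameters
assert opt_sigma is not opt          # independent instance, fresh state
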
@@ -276,22 +309,23 @@ def fit_estimator(self, X, y, sample_weight=None,
         X, y = check_arrays(X, y, accept_sparse=True)
         set_random_seed(random_state)
 
-        if (not warm_start) or (not hasattr(self, "estimator_")):
-            estimator = self.estimator
-            self.estimator_ = check_estimator(estimator,
+        if not hasattr(self, "estimator_"):
+            self.estimator_ = check_estimator(self.estimator,
                                               copy=self.copy,
                                               force_copy=True)
-            if isinstance(self.estimator_, Model):
-                compile_params = {}
-                if estimator._is_compiled:
-                    compile_params["loss"] = deepcopy(estimator.loss)
-                    compile_params["optimizer"] = deepcopy(estimator.optimizer)
-                else:
-                    raise ValueError("The given `estimator` argument"
-                                     " is not compiled yet. "
-                                     "Please give a compiled estimator or "
-                                     "give `loss` and `optimizer` arguments.")
-                self.estimator_.compile(**compile_params)
+
+        estimator = self.estimator
+        if isinstance(self.estimator_, Model):
+            compile_params = {}
+            if check_if_compiled(estimator):
+                compile_params["loss"] = deepcopy(estimator.loss)
+                compile_params["optimizer"] = deepcopy(estimator.optimizer)
+            else:
+                raise ValueError("The given `estimator` argument"
+                                 " is not compiled yet. "
+                                 "Please give a compiled estimator or "
+                                 "give `loss` and `optimizer` arguments.")
+            self.estimator_.compile(**compile_params)
 
         fit_args = [
             p.name
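
Note: `check_if_compiled` is imported from `adapt.utils` at the top of this diff, but its body is not shown here. Judging by the `estimator._is_compiled` access it replaces, a minimal sketch of such a helper might look like the following (an assumption, not the library's actual code):

def check_if_compiled(model):
    # tf.keras sets a private _is_compiled flag once compile() has run;
    # getattr keeps the check safe for objects that lack the flag.
    return getattr(model, "_is_compiled", False)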