 from sklearn.exceptions import NotFittedError
 from tensorflow.keras import Model
 from tensorflow.keras.wrappers.scikit_learn import KerasClassifier, KerasRegressor
+try:
+    from tensorflow.keras.optimizers.legacy import RMSprop
+except ImportError:
+    from tensorflow.keras.optimizers import RMSprop
+
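
Note: the guarded import above targets TensorFlow releases where the original
optimizer implementations live under tensorflow.keras.optimizers.legacy (the
new Keras optimizers became the default around TF 2.11); on older releases that
namespace does not exist, so the plain import is the fallback. A minimal sketch
of the same pattern applied to another optimizer (Adam here is only an
illustration, not part of this commit):

try:
    # Recent TF: keep the old optimizer behavior via the legacy namespace.
    from tensorflow.keras.optimizers.legacy import Adam
except ImportError:
    # Older TF: only the top-level namespace exists.
    from tensorflow.keras.optimizers import Adam

optimizer = Adam(learning_rate=1e-3)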
 
 from adapt.utils import (check_estimator,
                          check_network,
@@ -282,8 +287,8 @@ def unsupervised_score(self, Xs, Xt):
         score : float
             Unsupervised score.
         """
-        Xs = check_array(np.array(Xs))
-        Xt = check_array(np.array(Xt))
+        Xs = check_array(Xs, accept_sparse=True)
+        Xt = check_array(Xt, accept_sparse=True)
 
         if hasattr(self, "transform"):
             args = [
@@ -306,13 +311,11 @@ def unsupervised_score(self, Xs, Xt):
 
             set_random_seed(self.random_state)
             bootstrap_index = np.random.choice(
-                len(Xs), size=len(Xs), replace=True, p=sample_weight)
+                Xs.shape[0], size=Xs.shape[0], replace=True, p=sample_weight)
             Xs = Xs[bootstrap_index]
         else:
             raise ValueError("The Adapt model should implement"
                              " a transform or predict_weights method")
-        Xs = np.array(Xs)
-        Xt = np.array(Xt)
         return normalized_linear_discrepancy(Xs, Xt)
 
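
Note on the hunk above: the predict_weights branch performs a weighted
bootstrap, resampling the source data in proportion to the learned importance
weights before the discrepancy is computed. A standalone sketch of that step
(function and variable names are illustrative, not part of the library):

import numpy as np

def weighted_bootstrap(Xs, sample_weight, seed=None):
    # Normalize the weights into a probability vector.
    p = np.asarray(sample_weight, dtype=float)
    p = p / p.sum()
    rng = np.random.RandomState(seed)
    # Draw as many indices as there are samples, with replacement,
    # so heavily weighted samples appear more often.
    idx = rng.choice(Xs.shape[0], size=Xs.shape[0], replace=True, p=p)
    return Xs[idx]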
318321
@@ -466,18 +469,27 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
466469 """
467470 Xt , yt = self ._get_target_data (Xt , yt )
468471 X , y = check_arrays (X , y )
472+ self .n_features_in_ = X .shape [1 ]
469473 if yt is not None :
470474 Xt , yt = check_arrays (Xt , yt )
471475 else :
472476 Xt = check_array (Xt , ensure_2d = True , allow_nd = True )
473477 set_random_seed (self .random_state )
478+
479+ self .n_features_in_ = X .shape [1 ]
474480
475481 if hasattr (self , "fit_weights" ):
476482 if self .verbose :
477483 print ("Fit weights..." )
478- self .weights_ = self .fit_weights (Xs = X , Xt = Xt ,
479- ys = y , yt = yt ,
480- domains = domains )
484+ out = self .fit_weights (Xs = X , Xt = Xt ,
485+ ys = y , yt = yt ,
486+ domains = domains )
487+ if isinstance (out , tuple ):
488+ self .weights_ = out [0 ]
489+ X = out [1 ]
490+ y = out [2 ]
491+ else :
492+ self .weights_ = out
481493 if "sample_weight" in fit_params :
482494 fit_params ["sample_weight" ] *= self .weights_
483495 else :
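
Note on the hunk above: fit_weights may now return either the weights alone or
a (weights, X, y) tuple when the weighting step also resamples the training
data. A hypothetical subclass illustrating the tuple contract (this class is
not part of the commit):

import numpy as np

class ResamplingAdapt:
    """Toy Adapt-style model whose fit_weights also resamples X and y."""

    def fit_weights(self, Xs, Xt, ys=None, yt=None, domains=None):
        weights = np.ones(Xs.shape[0])
        idx = np.random.choice(Xs.shape[0], size=Xs.shape[0], replace=True)
        # Returning a tuple signals fit() to replace X and y as well.
        return weights[idx], Xs[idx], ys[idx]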
@@ -534,7 +546,7 @@ def fit_estimator(self, X, y, sample_weight=None,
         -------
         estimator_ : fitted estimator
         """
-        X, y = check_arrays(X, y)
+        X, y = check_arrays(X, y, accept_sparse=True)
         set_random_seed(random_state)
 
         if (not warm_start) or (not hasattr(self, "estimator_")):
@@ -613,7 +625,7 @@ def predict_estimator(self, X, **predict_params):
         y_pred : array
             prediction of estimator.
         """
-        X = check_array(X, ensure_2d=True, allow_nd=True)
+        X = check_array(X, ensure_2d=True, allow_nd=True, accept_sparse=True)
         predict_params = self._filter_params(self.estimator_.predict,
                                              predict_params)
         return self.estimator_.predict(X, **predict_params)
@@ -648,7 +660,7 @@ def predict(self, X, domain=None, **predict_params):
         y_pred : array
             prediction of the Adapt Model.
         """
-        X = check_array(X, ensure_2d=True, allow_nd=True)
+        X = check_array(X, ensure_2d=True, allow_nd=True, accept_sparse=True)
         if hasattr(self, "transform"):
             if domain is None:
                 domain = "tgt"
@@ -700,7 +712,7 @@ def score(self, X, y, sample_weight=None, domain=None):
         score : float
             estimator score.
         """
-        X, y = check_arrays(X, y)
+        X, y = check_arrays(X, y, accept_sparse=True)
 
         if domain is None:
             domain = "target"
@@ -788,7 +800,6 @@ def _get_legal_params(self, params):
 
 
     def __getstate__(self):
-        print("getting")
         dict_ = {k: v for k, v in self.__dict__.items()}
         if "estimator_" in dict_:
             if isinstance(dict_["estimator_"], Model):
@@ -810,7 +821,6 @@ def __getstate__(self):
 
 
     def __setstate__(self, dict_):
-        print("setting")
         if "estimator_" in dict_:
             if isinstance(dict_["estimator_"], dict):
                 dict_["estimator_"] = self._from_config_keras_model(
@@ -960,9 +970,10 @@ def fit(self, X, y=None, Xt=None, yt=None, domains=None, **fit_params):
         epochs = fit_params.get("epochs", 1)
         batch_size = fit_params.pop("batch_size", 32)
         shuffle = fit_params.pop("shuffle", True)
+        buffer_size = fit_params.pop("buffer_size", None)
         validation_data = fit_params.pop("validation_data", None)
         validation_split = fit_params.pop("validation_split", 0.)
-        validation_batch_size = fit_params.pop("validation_batch_size", batch_size)
+        validation_batch_size = fit_params.get("validation_batch_size", batch_size)
 
         # 2. Prepare datasets
 
@@ -1000,8 +1011,7 @@ def fit(self, X, y=None, Xt=None, yt=None, domains=None, **fit_params):
                               for dom in range(self.n_sources_))
             )
 
-            dataset_src = tf.data.Dataset.zip((dataset_Xs, dataset_ys))
-
+            dataset_src = tf.data.Dataset.zip((dataset_Xs, dataset_ys))
         else:
             dataset_src = X
 
@@ -1031,47 +1041,62 @@ def fit(self, X, y=None, Xt=None, yt=None, domains=None, **fit_params):
             self._initialize_networks()
         if isinstance(Xt, tf.data.Dataset):
             first_elem = next(iter(Xt))
-            if (not isinstance(first_elem, tuple) or
-                not len(first_elem)==2):
-                raise ValueError("When first argument is a dataset. "
-                                 "It should return (x, y) tuples.")
+            if not isinstance(first_elem, tuple):
+                shape = first_elem.shape
             else:
                 shape = first_elem[0].shape
+            if self._check_for_batch(Xt):
+                shape = shape[1:]
         else:
             shape = Xt.shape[1:]
         self._initialize_weights(shape)
 
-        # validation_data = self._check_validation_data(validation_data,
-        #                                               validation_batch_size,
-        #                                               shuffle)
+
+        # 3.5 Get datasets length
+        self.length_src_ = self._get_length_dataset(dataset_src, domain="src")
+        self.length_tgt_ = self._get_length_dataset(dataset_tgt, domain="tgt")
+
 
         # 4. Prepare validation dataset
         if validation_data is None and validation_split > 0.:
             if shuffle:
-                dataset_src = dataset_src.shuffle(buffer_size=1024)
-            frac = int(len(dataset_src)*validation_split)
+                dataset_src = dataset_src.shuffle(buffer_size=self.length_src_,
+                                                  reshuffle_each_iteration=False)
+            frac = int(self.length_src_*validation_split)
             validation_data = dataset_src.take(frac)
             dataset_src = dataset_src.skip(frac)
-            validation_data = validation_data.batch(batch_size)
+            if not self._check_for_batch(validation_data):
+                validation_data = validation_data.batch(validation_batch_size)
+
+        if validation_data is not None:
+            if isinstance(validation_data, tf.data.Dataset):
+                if not self._check_for_batch(validation_data):
+                    validation_data = validation_data.batch(validation_batch_size)
 
+
         # 5. Set datasets
         # Same length for src and tgt + complete last batch + shuffle
-        try:
-            max_size = max(len(dataset_src), len(dataset_tgt))
-            max_size = np.ceil(max_size/batch_size)*batch_size
-            repeat_src = np.ceil(max_size/len(dataset_src))
-            repeat_tgt = np.ceil(max_size/len(dataset_tgt))
-
-            dataset_src = dataset_src.repeat(repeat_src)
-            dataset_tgt = dataset_tgt.repeat(repeat_tgt)
-
-            self.total_steps_ = float(np.ceil(max_size/batch_size)*epochs)
-        except:
-            pass
-
         if shuffle:
-            dataset_src = dataset_src.shuffle(buffer_size=1024)
-            dataset_tgt = dataset_tgt.shuffle(buffer_size=1024)
+            if buffer_size is None:
+                dataset_src = dataset_src.shuffle(buffer_size=self.length_src_,
+                                                  reshuffle_each_iteration=True)
+                dataset_tgt = dataset_tgt.shuffle(buffer_size=self.length_tgt_,
+                                                  reshuffle_each_iteration=True)
+            else:
+                dataset_src = dataset_src.shuffle(buffer_size=buffer_size,
+                                                  reshuffle_each_iteration=True)
+                dataset_tgt = dataset_tgt.shuffle(buffer_size=buffer_size,
+                                                  reshuffle_each_iteration=True)
+
+        max_size = max(self.length_src_, self.length_tgt_)
+        max_size = np.ceil(max_size/batch_size)*batch_size
+        repeat_src = np.ceil(max_size/self.length_src_)
+        repeat_tgt = np.ceil(max_size/self.length_tgt_)
+
+        dataset_src = dataset_src.repeat(repeat_src).take(max_size)
+        dataset_tgt = dataset_tgt.repeat(repeat_tgt).take(max_size)
+
+        self.total_steps_ = float(np.ceil(max_size/batch_size)*epochs)
 
         # 6. Pretraining
         if not hasattr(self, "pretrain_"):
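
Note on the repeat/take block above: tf.data.Dataset.zip stops at the shorter
of its inputs, so source and target are each repeated and then truncated to a
common max_size that is a whole number of batches. A standalone sketch of the
idea with made-up sizes:

import numpy as np
import tensorflow as tf

src = tf.data.Dataset.range(10)    # pretend source dataset, length 10
tgt = tf.data.Dataset.range(25)    # pretend target dataset, length 25

batch_size = 8
max_size = int(np.ceil(25 / batch_size) * batch_size)          # 32 samples
src = src.repeat(int(np.ceil(max_size / 10))).take(max_size)   # 10 -> 32
tgt = tgt.repeat(int(np.ceil(max_size / 25))).take(max_size)   # 25 -> 32

# zip now yields exactly max_size aligned pairs instead of stopping at 10.
pairs = tf.data.Dataset.zip((src, tgt)).batch(batch_size)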
@@ -1099,14 +1124,14 @@ def fit(self, X, y=None, Xt=None, yt=None, domains=None, **fit_params):
             pre_verbose = prefit_params.pop("verbose", verbose)
             pre_epochs = prefit_params.pop("epochs", epochs)
             pre_batch_size = prefit_params.pop("batch_size", batch_size)
-            pre_shuffle = prefit_params.pop("shuffle", shuffle)
             prefit_params.pop("validation_data", None)
-            prefit_params.pop("validation_split", None)
-            prefit_params.pop("validation_batch_size", None)
 
             # !!! shuffle is already done
-            dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt)).batch(pre_batch_size)
-
+            dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt))
+
+            if not self._check_for_batch(dataset):
+                dataset = dataset.batch(pre_batch_size)
+
             hist = super().fit(dataset, validation_data=validation_data,
                                epochs=pre_epochs, verbose=pre_verbose, **prefit_params)
 
@@ -1123,7 +1148,10 @@ def fit(self, X, y=None, Xt=None, yt=None, domains=None, **fit_params):
             self.history_ = {}
 
         # 7. Training
-        dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt)).batch(batch_size)
+        dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt))
+
+        if not self._check_for_batch(dataset):
+            dataset = dataset.batch(batch_size)
 
         self.pretrain_ = False
 
@@ -1259,7 +1287,8 @@ def compile(self,
                 if "_" in name:
                     new_name = ""
                     for split in name.split("_"):
-                        new_name += split[0]
+                        if len(split) > 0:
+                            new_name += split[0]
                     name = new_name
                 else:
                     name = name[:3]
@@ -1284,7 +1313,7 @@ def compile(self,
 
         if ((not "optimizer" in compile_params) or
             (compile_params["optimizer"] is None)):
-            compile_params["optimizer"] = "rmsprop"
+            compile_params["optimizer"] = RMSprop()
         else:
             if optimizer is None:
                 if not isinstance(compile_params["optimizer"], str):
@@ -1331,7 +1360,8 @@ def train_step(self, data):
             loss = tf.reduce_mean(loss)
 
         # Run backwards pass.
-        self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
+        gradients = tape.gradient(loss, self.trainable_variables)
+        self.optimizer.apply_gradients(zip(gradients, self.trainable_variables))
         self.compiled_metrics.update_state(ys, y_pred)
         # Collect metrics to return
         return_metrics = {}
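
Note on the change above: computing gradients explicitly and calling
apply_gradients is equivalent to optimizer.minimize(loss, variables, tape=tape),
but it avoids relying on the tape keyword, which not every optimizer
implementation accepts. The pattern in isolation (a generic sketch, not code
from this file):

import tensorflow as tf

def gradient_step(model, optimizer, loss_fn, x, y):
    with tf.GradientTape() as tape:
        y_pred = model(x, training=True)
        loss = tf.reduce_mean(loss_fn(y, y_pred))
    # Explicit backwards pass: gradients first, then a single update.
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss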
@@ -1573,6 +1603,37 @@ def _initialize_weights(self, shape_X):
         X_enc = self.encoder_(np.zeros((1,) + shape_X))
         if hasattr(self, "discriminator_"):
             self.discriminator_(X_enc)
+
+
+    def _get_length_dataset(self, dataset, domain="src"):
+        try:
+            length = len(dataset)
+        except TypeError:
+            if self.verbose:
+                print("Computing %s dataset size..." % domain)
+            if not hasattr(self, "length_%s_" % domain):
+                length = 0
+                for _ in dataset:
+                    length += 1
+            else:
+                length = getattr(self, "length_%s_" % domain)
+            if self.verbose:
+                print("Done!")
+        return length
+
+
+    def _check_for_batch(self, dataset):
+        if dataset.__class__.__name__ == "BatchDataset":
+            return True
+        if hasattr(dataset, "_input_dataset"):
+            return self._check_for_batch(dataset._input_dataset)
+        elif hasattr(dataset, "_datasets"):
+            checks = []
+            for data in dataset._datasets:
+                checks.append(self._check_for_batch(data))
+            return np.all(checks)
+        else:
+            return False
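
Note on _check_for_batch: it walks tf.data's private wrapper chain
(_input_dataset for chained transformations, _datasets for zipped ones) to
detect whether batch() was already applied, so these attributes may change
between TF versions. A standalone illustration of what it reports:

import numpy as np
import tensorflow as tf

def check_for_batch(dataset):
    # Standalone copy of the helper above, for illustration only.
    if dataset.__class__.__name__ == "BatchDataset":
        return True
    if hasattr(dataset, "_input_dataset"):
        return check_for_batch(dataset._input_dataset)
    elif hasattr(dataset, "_datasets"):
        return np.all([check_for_batch(d) for d in dataset._datasets])
    return False

ds = tf.data.Dataset.range(100).shuffle(10)
print(check_for_batch(ds))            # False: no batch() in the chain
print(check_for_batch(ds.batch(16)))  # True: BatchDataset found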
 
 
     def _unpack_data(self, data):
@@ -1596,23 +1657,23 @@ def _get_disc_metrics(self, ys_disc, yt_disc):
 
     def _initialize_networks(self):
         if self.encoder is None:
-            self.encoder_ = get_default_encoder(name="encoder")
+            self.encoder_ = get_default_encoder(name="encoder", state=self.random_state)
         else:
             self.encoder_ = check_network(self.encoder,
                                           copy=self.copy,
                                           name="encoder")
         if self.task is None:
-            self.task_ = get_default_task(name="task")
+            self.task_ = get_default_task(name="task", state=self.random_state)
         else:
             self.task_ = check_network(self.task,
                                        copy=self.copy,
                                        name="task")
         if self.discriminator is None:
-            self.discriminator_ = get_default_discriminator(name="discriminator")
+            self.discriminator_ = get_default_discriminator(name="discriminator", state=self.random_state)
         else:
             self.discriminator_ = check_network(self.discriminator,
                                                 copy=self.copy,
                                                 name="discriminator")
 
     def _initialize_pretain_networks(self):
-        pass
+        pass