@@ -367,7 +367,7 @@ def _save_validation_data(self, Xs, Xt):
         else:
             self.Xs_ = Xs
             self.Xt_ = Xt
-            self.src_index_ = np.arange(len(Xs))
+            self.src_index_ = None


     def _get_target_data(self, X, y):
@@ -458,7 +458,7 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
         if yt is not None:
             Xt, yt = check_arrays(Xt, yt)
         else:
-            Xt = check_array(Xt)
+            Xt = check_array(Xt, ensure_2d=True, allow_nd=True)
         set_random_seed(self.random_state)

         self._save_validation_data(X, Xt)
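The stricter validation matters once targets can be batches of images or other tensors with more than two dimensions. Assuming this check_array mirrors sklearn's semantics for ensure_2d and allow_nd (an assumption; sklearn's own implementation is used below for illustration):

import numpy as np
from sklearn.utils import check_array

Xt = np.random.rand(16, 32, 32, 3)              # e.g. a batch of RGB images
check_array(Xt, ensure_2d=True, allow_nd=True)  # accepted: ndim > 2 is allowed
# check_array(Xt)                               # raises: allow_nd defaults to False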
@@ -857,7 +857,7 @@ def __init__(self,
         self._self_setattr_tracking = True


-    def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
+    def fit(self, X, y=None, Xt=None, yt=None, domains=None, **fit_params):
         """
         Fit Model. Note that ``fit`` does not reset
         the model but extend the training.
@@ -867,7 +867,7 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
         X : array or Tensor
             Source input data.

-        y : array or Tensor
+        y : array or Tensor (default=None)
             Source output data.

         Xt : array (default=None)
@@ -889,71 +889,126 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
         Returns
         -------
         self : returns an instance of self
-        """
+        """
         set_random_seed(self.random_state)

         # 1. Initialize networks
         if not hasattr(self, "_is_fitted"):
             self._is_fitted = True
             self._initialize_networks()
-            self._initialize_weights(X.shape[1:])
+            if isinstance(X, tf.data.Dataset):
+                first_elem = next(iter(X))
+                if (not isinstance(first_elem, tuple) or
+                        not len(first_elem) == 2):
+                    raise ValueError("When the first argument is a dataset, "
+                                     "it should return (x, y) tuples.")
+                else:
+                    shape = first_elem[0].shape
+            else:
+                shape = X.shape[1:]
+            self._initialize_weights(shape)
+
+        # 2. Get fit params
+        fit_params = self._filter_params(super().fit, fit_params)

-        # 2. Prepare dataset
+        verbose = fit_params.get("verbose", 1)
+        epochs = fit_params.get("epochs", 1)
+        batch_size = fit_params.pop("batch_size", 32)
+        shuffle = fit_params.pop("shuffle", True)
+        validation_data = fit_params.pop("validation_data", None)
+        validation_split = fit_params.pop("validation_split", 0.)
+        validation_batch_size = fit_params.pop("validation_batch_size", batch_size)
+
+        # 3. Prepare datasets
+
+        ### 3.1 Source
+        if not isinstance(X, tf.data.Dataset):
+            check_arrays(X, y)
+            if len(y.shape) <= 1:
+                y = y.reshape(-1, 1)
+
+            # Single source
+            if domains is None:
+                self.n_sources_ = 1
+
+                dataset_Xs = tf.data.Dataset.from_tensor_slices(X)
+                dataset_ys = tf.data.Dataset.from_tensor_slices(y)
+
+            # Multisource
+            else:
+                domains = self._check_domains(domains)
+                self.n_sources_ = int(np.max(domains) + 1)
+
+                sizes = [np.sum(domains == dom)
+                         for dom in range(self.n_sources_)]
+
+                max_size = np.max(sizes)
+                repeats = np.ceil(max_size / sizes)
+
+                dataset_Xs = tf.data.Dataset.zip(tuple(
+                    tf.data.Dataset.from_tensor_slices(X[domains == dom]).repeat(repeats[dom])
+                    for dom in range(self.n_sources_))
+                )
+
+                dataset_ys = tf.data.Dataset.zip(tuple(
+                    tf.data.Dataset.from_tensor_slices(y[domains == dom]).repeat(repeats[dom])
+                    for dom in range(self.n_sources_))
+                )
+
+            dataset_src = tf.data.Dataset.zip((dataset_Xs, dataset_ys))
+
+        else:
+            dataset_src = X
+
+        ### 3.2 Target
         Xt, yt = self._get_target_data(Xt, yt)
+        if not isinstance(Xt, tf.data.Dataset):
+            if yt is None:
+                check_array(Xt, ensure_2d=True, allow_nd=True)
+                dataset_tgt = tf.data.Dataset.from_tensor_slices(Xt)

-        check_arrays(X, y)
-        if len(y.shape) <= 1:
-            y = y.reshape(-1, 1)
+            else:
+                check_arrays(Xt, yt)
+
+                if len(yt.shape) <= 1:
+                    yt = yt.reshape(-1, 1)
+
+                dataset_Xt = tf.data.Dataset.from_tensor_slices(Xt)
+                dataset_yt = tf.data.Dataset.from_tensor_slices(yt)
+                dataset_tgt = tf.data.Dataset.zip((dataset_Xt, dataset_yt))

-        if yt is None:
-            yt = y
-            check_array(Xt)
         else:
-            check_arrays(Xt, yt)
-
-            if len(yt.shape) <= 1:
-                yt = yt.reshape(-1, 1)
+            dataset_tgt = Xt

         self._save_validation_data(X, Xt)

-        domains = fit_params.pop("domains", None)
-
-        if domains is None:
-            domains = np.zeros(len(X))
-
-        domains = self._check_domains(domains)
-
-        self.n_sources_ = int(np.max(domains) + 1)
-
-        sizes = np.array(
-            [np.sum(domains == dom) for dom in range(self.n_sources_)] +
-            [len(Xt)])
-
-        max_size = np.max(sizes)
-        repeats = np.ceil(max_size / sizes)
-
-        dataset_X = tf.data.Dataset.zip(tuple(
-            tf.data.Dataset.from_tensor_slices(X[domains == dom]).repeat(repeats[dom])
-            for dom in range(self.n_sources_)) +
-            (tf.data.Dataset.from_tensor_slices(Xt).repeat(repeats[-1]),)
-        )
-
-        dataset_y = tf.data.Dataset.zip(tuple(
-            tf.data.Dataset.from_tensor_slices(y[domains == dom]).repeat(repeats[dom])
-            for dom in range(self.n_sources_)) +
-            (tf.data.Dataset.from_tensor_slices(yt).repeat(repeats[-1]),)
-        )
-
-
-        # 3. Get Fit params
-        fit_params = self._filter_params(super().fit, fit_params)
-
-        verbose = fit_params.get("verbose", 1)
-        epochs = fit_params.get("epochs", 1)
-        batch_size = fit_params.pop("batch_size", 32)
-        shuffle = fit_params.pop("shuffle", True)
+        # 4. Get validation data
+        # validation_data = self._check_validation_data(validation_data,
+        #                                               validation_batch_size,
+        #                                               shuffle)
+
+        if validation_data is None and validation_split > 0.:
+            if shuffle:
+                dataset_src = dataset_src.shuffle(buffer_size=1024)
+            frac = int(len(dataset_src) * validation_split)
+            validation_data = dataset_src.take(frac)
+            dataset_src = dataset_src.skip(frac)
+            validation_data = validation_data.batch(batch_size)
+
+        # 5. Set datasets
+        try:
+            max_size = max(len(dataset_src), len(dataset_tgt))
+            repeat_src = np.ceil(max_size / len(dataset_src))
+            repeat_tgt = np.ceil(max_size / len(dataset_tgt))
+
+            dataset_src = dataset_src.repeat(repeat_src)
+            dataset_tgt = dataset_tgt.repeat(repeat_tgt)
+
+            self.total_steps_ = float(np.ceil(max_size / batch_size) * epochs)
+        except:
+            pass

-        # 4. Pretraining
+        # 6. Pretraining
         if not hasattr(self, "pretrain_"):
             if not hasattr(self, "pretrain"):
                 self.pretrain_ = False
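With this rewrite, fit accepts plain arrays or tf.data.Dataset objects on both the source and the target side, which is why y now defaults to None. A hedged usage sketch (model stands for any instance of this class; only argument names appearing in the code above are used):

import numpy as np
import tensorflow as tf

Xs, ys = np.random.randn(100, 10), np.random.randn(100, 1)
Xt = np.random.randn(60, 10)

# Arrays, multi-source: domains[i] is the source domain index of sample i.
domains = np.random.randint(0, 3, size=100)
model.fit(Xs, ys, Xt=Xt, domains=domains, epochs=5, batch_size=32)

# tf.data pipelines: the source dataset must yield (x, y) tuples.
src = tf.data.Dataset.from_tensor_slices((Xs, ys))
tgt = tf.data.Dataset.from_tensor_slices(Xt)
model.fit(src, Xt=tgt, epochs=5)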
@@ -980,36 +1035,39 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
             pre_epochs = prefit_params.pop("epochs", epochs)
             pre_batch_size = prefit_params.pop("batch_size", batch_size)
             pre_shuffle = prefit_params.pop("shuffle", shuffle)
+            prefit_params.pop("validation_data", None)
+            prefit_params.pop("validation_split", None)
+            prefit_params.pop("validation_batch_size", None)

             if pre_shuffle:
-                dataset = tf.data.Dataset.zip((dataset_X, dataset_y)).shuffle(buffer_size=1024).batch(pre_batch_size)
+                dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt)).shuffle(buffer_size=1024).batch(pre_batch_size)
             else:
-                dataset = tf.data.Dataset.zip((dataset_X, dataset_y)).batch(pre_batch_size)
+                dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt)).batch(pre_batch_size)

-            hist = super().fit(dataset, epochs=pre_epochs, verbose=pre_verbose, **prefit_params)
+            hist = super().fit(dataset, validation_data=validation_data,
+                               epochs=pre_epochs, verbose=pre_verbose, **prefit_params)

             for k, v in hist.history.items():
                 self.pretrain_history_[k] = self.pretrain_history_.get(k, []) + v

             self._initialize_pretain_networks()
-
-        # 5. Training
+
+        # 7. Compile
         if (not self._is_compiled) or (self.pretrain_):
             self.compile()

         if not hasattr(self, "history_"):
             self.history_ = {}

+        # 8. Training
         if shuffle:
-            dataset = tf.data.Dataset.zip((dataset_X, dataset_y)).shuffle(buffer_size=1024).batch(batch_size)
+            dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt)).shuffle(buffer_size=1024).batch(batch_size)
         else:
-            dataset = tf.data.Dataset.zip((dataset_X, dataset_y)).batch(batch_size)
-
+            dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt)).batch(batch_size)
+
         self.pretrain_ = False
-        self.steps_ = tf.Variable(0.)
-        self.total_steps_ = float(np.ceil(max_size / batch_size) * epochs)

-        hist = super().fit(dataset, **fit_params)
+        hist = super().fit(dataset, validation_data=validation_data, **fit_params)

         for k, v in hist.history.items():
             self.history_[k] = self.history_.get(k, []) + v
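The repeat logic in step 5 exists because tf.data.Dataset.zip stops at its shortest input: the smaller of the two datasets is cycled so one full pass covers the larger one. A standalone sketch of the idea with illustrative sizes:

import numpy as np
import tensorflow as tf

src = tf.data.Dataset.range(100)                     # 100 source samples
tgt = tf.data.Dataset.range(40)                      # 40 target samples

max_size = max(len(src), len(tgt))                   # 100
tgt = tgt.repeat(int(np.ceil(max_size / len(tgt))))  # 3 repeats -> 120 samples

batches = tf.data.Dataset.zip((src, tgt)).batch(32)  # zip now yields 100 pairs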
@@ -1188,6 +1246,12 @@ def compile(self,
         super().compile(
             **compile_params
         )
+
+        # Set optimizer for encoder and discriminator
+        if not hasattr(self, "optimizer_enc"):
+            self.optimizer_enc = self.optimizer
+        if not hasattr(self, "optimizer_disc"):
+            self.optimizer_disc = self.optimizer


    def call(self, inputs):
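Together with the "optimizer_enc"/"optimizer_disc" entries added to the legal parameters further down, this lets adversarial models drive the encoder and the discriminator with distinct optimizers, falling back to the main optimizer for whichever is not supplied. A hypothetical sketch (the constructor arguments are inferred from the legal-params change, not from a documented API):

import tensorflow as tf

model = SomeAdversarialModel(                      # hypothetical subclass of this base class
    optimizer=tf.keras.optimizers.Adam(1e-3),
    optimizer_enc=tf.keras.optimizers.Adam(1e-4),  # slower encoder updates
)
model.compile()  # optimizer_disc was not given, so it falls back to optimizer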
@@ -1199,10 +1263,6 @@ def train_step(self, data):
         # Unpack the data.
         Xs, Xt, ys, yt = self._unpack_data(data)

-        # Single source
-        Xs = Xs[0]
-        ys = ys[0]
-
         # Run forward pass.
         with tf.GradientTape() as tape:
             y_pred = self(Xs, training=True)
@@ -1376,7 +1436,7 @@ def score_estimator(self, X, y, sample_weight=None):
         score : float
             Score.
         """
-        if np.prod(X.shape) <= 10**8:
+        if hasattr(X, "shape") and np.prod(X.shape) <= 10**8:
             score = self.evaluate(
                 X, y,
                 sample_weight=sample_weight,
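The added hasattr guard keeps score_estimator from raising an AttributeError when X is a tf.data.Dataset, which exposes no global shape; such inputs presumably fall through to the batched evaluation branch below instead of the in-memory size check.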
@@ -1390,6 +1450,22 @@ def score_estimator(self, X, y, sample_weight=None):
         if isinstance(score, (tuple, list)):
             score = score[0]
         return score
+
+
+    # def _check_validation_data(self, validation_data, batch_size, shuffle):
+    #     if isinstance(validation_data, tuple):
+    #         X_val = validation_data[0]
+    #         y_val = validation_data[1]

+    #         validation_data = tf.data.Dataset.zip(
+    #             (tf.data.Dataset.from_tensor_slices(X_val),
+    #              tf.data.Dataset.from_tensor_slices(y_val))
+    #         )
+    #         if shuffle:
+    #             validation_data = validation_data.shuffle(buffer_size=1024).batch(batch_size)
+    #         else:
+    #             validation_data = validation_data.batch(batch_size)
+    #         return validation_data


     def _get_legal_params(self, params):
@@ -1405,7 +1481,7 @@ def _get_legal_params(self, params):
         if (optimizer is not None) and (not isinstance(optimizer, str)):
             legal_params_fct.append(optimizer.__init__)

-        legal_params = ["domain", "val_sample_size"]
+        legal_params = ["domain", "val_sample_size", "optimizer_enc", "optimizer_disc"]
         for func in legal_params_fct:
             args = [
                 p.name
@@ -1439,13 +1515,17 @@ def _initialize_weights(self, shape_X):


     def _unpack_data(self, data):
-        data_X = data[0]
-        data_y = data[1]
-        Xs = data_X[:-1]
-        Xt = data_X[-1]
-        ys = data_y[:-1]
-        yt = data_y[-1]
-        return Xs, Xt, ys, ys
+        data_src = data[0]
+        data_tgt = data[1]
+        Xs = data_src[0]
+        ys = data_src[1]
+        if isinstance(data_tgt, tuple):
+            Xt = data_tgt[0]
+            yt = data_tgt[1]
+            return Xs, Xt, ys, yt
+        else:
+            Xt = data_tgt
+            return Xs, Xt, ys, None


     def _get_disc_metrics(self, ys_disc, yt_disc):
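After this change, each element of the zipped training dataset is ((Xs, ys), Xt) for an unlabeled target or ((Xs, ys), (Xt, yt)) for a labeled one, and _unpack_data mirrors that structure (note that the old version returned ys twice, so yt was never actually propagated). A minimal illustration with dummy arrays (model again stands for an instance of this class):

import numpy as np

batch = ((np.ones((32, 10)), np.ones((32, 1))), np.zeros((32, 10)))
Xs, Xt, ys, yt = model._unpack_data(batch)   # unlabeled target: yt is None

batch = ((np.ones((32, 10)), np.ones((32, 1))),
         (np.zeros((32, 10)), np.zeros((32, 1))))
Xs, Xt, ys, yt = model._unpack_data(batch)   # labeled target: yt holds the labels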