@@ -367,7 +367,7 @@ def _save_validation_data(self, Xs, Xt):
367367 else :
368368 self .Xs_ = Xs
369369 self .Xt_ = Xt
370- self .src_index_ = np . arange ( len ( Xs ))
370+ self .src_index_ = None
371371
372372
373373 def _get_target_data (self , X , y ):
@@ -857,7 +857,7 @@ def __init__(self,
857857 self ._self_setattr_tracking = True
858858
859859
860- def fit (self , X , y , Xt = None , yt = None , domains = None , ** fit_params ):
860+ def fit (self , X , y = None , Xt = None , yt = None , domains = None , ** fit_params ):
861861 """
862862 Fit Model. Note that ``fit`` does not reset
863863 the model but extend the training.
@@ -867,7 +867,7 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
867867 X : array or Tensor
868868 Source input data.
869869
870- y : array or Tensor
870+ y : array or Tensor (default=None)
871871 Source output data.
872872
873873 Xt : array (default=None)
@@ -896,7 +896,18 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
896896 if not hasattr (self , "_is_fitted" ):
897897 self ._is_fitted = True
898898 self ._initialize_networks ()
899- self ._initialize_weights (X .shape [1 :])
899+ if isinstance (X , tf .data .Dataset ):
900+ first_elem = next (iter (X ))
901+ if (not isinstance (first_elem , tuple ) or
902+ not len (first_elem )== 2 ):
903+ raise ValueError ("When first argument is a dataset. "
904+ "It should return (x, y) tuples." )
905+ else :
906+ shape = first_elem [0 ].shape
907+ else :
908+ shape = X .shape [1 :]
909+ print (shape )
910+ self ._initialize_weights (shape )
900911
901912 # 2. Get Fit params
902913 fit_params = self ._filter_params (super ().fit , fit_params )
@@ -909,65 +920,96 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
909920 validation_split = fit_params .pop ("validation_split" , 0. )
910921 validation_batch_size = fit_params .pop ("validation_batch_size" , batch_size )
911922
912- # 3. Prepare dataset
923+ # 3. Prepare datasets
913924
914925 ### 3.1 Source
915926 if not isinstance (X , tf .data .Dataset ):
916927 check_arrays (X , y )
917928 if len (y .shape ) <= 1 :
918929 y = y .reshape (- 1 , 1 )
919930
931+ # Single source
932+ if domains is None :
933+ self .n_sources_ = 1
934+
935+ dataset_Xs = tf .data .Dataset .from_tensor_slices (X )
936+ dataset_ys = tf .data .Dataset .from_tensor_slices (y )
937+
938+ # Multisource
939+ else :
940+ domains = self ._check_domains (domains )
941+ self .n_sources_ = int (np .max (domains )+ 1 )
942+
943+ sizes = [np .sum (domains == dom )
944+ for dom in range (self .n_sources_ )]
945+
946+ max_size = np .max (sizes )
947+ repeats = np .ceil (max_size / sizes )
948+
949+ dataset_Xs = tf .data .Dataset .zip (tuple (
950+ tf .data .Dataset .from_tensor_slices (X [domains == dom ]).repeat (repeats [dom ])
951+ for dom in range (self .n_sources_ ))
952+ )
953+
954+ dataset_ys = tf .data .Dataset .zip (tuple (
955+ tf .data .Dataset .from_tensor_slices (y [domains == dom ]).repeat (repeats [dom ])
956+ for dom in range (self .n_sources_ ))
957+ )
958+
959+ dataset_src = tf .data .Dataset .zip ((dataset_Xs , dataset_ys ))
960+
961+ else :
962+ dataset_src = X
963+
920964 ### 3.2 Target
921965 Xt , yt = self ._get_target_data (Xt , yt )
922966 if not isinstance (Xt , tf .data .Dataset ):
923967 if yt is None :
924- yt = y
925968 check_array (Xt , ensure_2d = True , allow_nd = True )
969+ dataset_tgt = tf .data .Dataset .from_tensor_slices (Xt )
970+
926971 else :
927972 check_arrays (Xt , yt )
928973
929- if len (yt .shape ) <= 1 :
930- yt = yt .reshape (- 1 , 1 )
974+ if len (yt .shape ) <= 1 :
975+ yt = yt .reshape (- 1 , 1 )
976+
977+ dataset_Xt = tf .data .Dataset .from_tensor_slices (Xt )
978+ dataset_yt = tf .data .Dataset .from_tensor_slices (yt )
979+ dataset_tgt = tf .data .Dataset .zip ((dataset_Xt , dataset_yt ))
980+
981+ else :
982+ dataset_tgt = Xt
931983
932984 self ._save_validation_data (X , Xt )
933985
934- ### 3.3 Domains
935- domains = fit_params .pop ("domains" , None )
936-
937- if domains is None :
938- domains = np .zeros (len (X ))
939-
940- domains = self ._check_domains (domains )
941-
942- self .n_sources_ = int (np .max (domains )+ 1 )
943-
944- sizes = np .array (
945- [np .sum (domains == dom ) for dom in range (self .n_sources_ )]+
946- [len (Xt )])
947-
948- max_size = np .max (sizes )
949- repeats = np .ceil (max_size / sizes )
950-
951- # Split if validation_split
952- # if validation_data is None and validation_split>0.:
953- # frac = int(len(dataset)*validation_split)
954- # validation_data = dataset.take(frac)
955- # dataset = dataset.skip(frac)
956-
986+ # 4. Get validation data
987+ validation_data = self ._check_validation_data (validation_data ,
988+ validation_batch_size ,
989+ shuffle )
957990
958- dataset_X = tf .data .Dataset .zip (tuple (
959- tf .data .Dataset .from_tensor_slices (X [domains == dom ]).repeat (repeats [dom ])
960- for dom in range (self .n_sources_ ))+
961- (tf .data .Dataset .from_tensor_slices (Xt ).repeat (repeats [- 1 ]),)
962- )
963-
964- dataset_y = tf .data .Dataset .zip (tuple (
965- tf .data .Dataset .from_tensor_slices (y [domains == dom ]).repeat (repeats [dom ])
966- for dom in range (self .n_sources_ ))+
967- (tf .data .Dataset .from_tensor_slices (yt ).repeat (repeats [- 1 ]),)
968- )
991+ if validation_data is None and validation_split > 0. :
992+ if shuffle :
993+ dataset_src = dataset_src .shuffle (buffer_size = 1024 )
994+ frac = int (len (dataset_src )* validation_split )
995+ validation_data = dataset_src .take (frac )
996+ dataset_src = dataset_src .skip (frac )
997+ validation_data = validation_data .batch (batch_size )
998+
999+ # 5. Set datasets
1000+ try :
1001+ max_size = max (len (dataset_src ), len (dataset_tgt ))
1002+ repeat_src = np .ceil (max_size / len (dataset_src ))
1003+ repeat_tgt = np .ceil (max_size / len (dataset_tgt ))
1004+
1005+ dataset_src = dataset_src .repeat (repeat_src )
1006+ dataset_tgt = dataset_tgt .repeat (repeat_tgt )
1007+
1008+ self .total_steps_ = float (np .ceil (max_size / batch_size )* epochs )
1009+ except :
1010+ pass
9691011
970- # 4 . Pretraining
1012+ # 5 . Pretraining
9711013 if not hasattr (self , "pretrain_" ):
9721014 if not hasattr (self , "pretrain" ):
9731015 self .pretrain_ = False
@@ -994,32 +1036,22 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
9941036 pre_epochs = prefit_params .pop ("epochs" , epochs )
9951037 pre_batch_size = prefit_params .pop ("batch_size" , batch_size )
9961038 pre_shuffle = prefit_params .pop ("shuffle" , shuffle )
1039+ prefit_params .pop ("validation_data" , None )
1040+ prefit_params .pop ("validation_split" , None )
1041+ prefit_params .pop ("validation_batch_size" , None )
9971042
9981043 if pre_shuffle :
999- dataset = tf .data .Dataset .zip ((dataset_X , dataset_y )).shuffle (buffer_size = 1024 ).batch (pre_batch_size )
1044+ dataset = tf .data .Dataset .zip ((dataset_src , dataset_tgt )).shuffle (buffer_size = 1024 ).batch (pre_batch_size )
10001045 else :
1001- dataset = tf .data .Dataset .zip ((dataset_X , dataset_y )).batch (pre_batch_size )
1046+ dataset = tf .data .Dataset .zip ((dataset_src , dataset_tgt )).batch (pre_batch_size )
10021047
1003- hist = super ().fit (dataset , epochs = pre_epochs , verbose = pre_verbose , ** prefit_params )
1048+ hist = super ().fit (dataset , validation_data = validation_data ,
1049+ epochs = pre_epochs , verbose = pre_verbose , ** prefit_params )
10041050
10051051 for k , v in hist .history .items ():
10061052 self .pretrain_history_ [k ] = self .pretrain_history_ .get (k , []) + v
10071053
10081054 self ._initialize_pretain_networks ()
1009-
1010- # 5. Define validation Set
1011- if isinstance (validation_data , tuple ):
1012- X_val = validation_data [0 ]
1013- y_val = validation_data [1 ]
1014-
1015- validation_data = tf .data .Dataset .zip (
1016- (tf .data .Dataset .from_tensor_slices (X_val ),
1017- tf .data .Dataset .from_tensor_slices (y_val ))
1018- )
1019- if shuffle :
1020- validation_data = validation_data .shuffle (buffer_size = 1024 ).batch (batch_size )
1021- else :
1022- validation_data = validation_data .batch (batch_size )
10231055
10241056 # 6. Training
10251057 if (not self ._is_compiled ) or (self .pretrain_ ):
@@ -1029,13 +1061,12 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
10291061 self .history_ = {}
10301062
10311063 if shuffle :
1032- dataset = tf .data .Dataset .zip ((dataset_X , dataset_y )).shuffle (buffer_size = 1024 ).batch (batch_size )
1064+ dataset = tf .data .Dataset .zip ((dataset_src , dataset_tgt )).shuffle (buffer_size = 1024 ).batch (batch_size )
10331065 else :
1034- dataset = tf .data .Dataset .zip ((dataset_X , dataset_y )).batch (batch_size )
1066+ dataset = tf .data .Dataset .zip ((dataset_src , dataset_tgt )).batch (batch_size )
10351067
10361068 self .pretrain_ = False
10371069 self .steps_ = tf .Variable (0. )
1038- self .total_steps_ = float (np .ceil (max_size / batch_size )* epochs )
10391070
10401071 hist = super ().fit (dataset , validation_data = validation_data , ** fit_params )
10411072
@@ -1227,10 +1258,6 @@ def train_step(self, data):
12271258 # Unpack the data.
12281259 Xs , Xt , ys , yt = self ._unpack_data (data )
12291260
1230- # Single source
1231- Xs = Xs [0 ]
1232- ys = ys [0 ]
1233-
12341261 # Run forward pass.
12351262 with tf .GradientTape () as tape :
12361263 y_pred = self (Xs , training = True )
@@ -1418,6 +1445,22 @@ def score_estimator(self, X, y, sample_weight=None):
14181445 if isinstance (score , (tuple , list )):
14191446 score = score [0 ]
14201447 return score
1448+
1449+
1450+ def _check_validation_data (self , validation_data , batch_size , shuffle ):
1451+ if isinstance (validation_data , tuple ):
1452+ X_val = validation_data [0 ]
1453+ y_val = validation_data [1 ]
1454+
1455+ validation_data = tf .data .Dataset .zip (
1456+ (tf .data .Dataset .from_tensor_slices (X_val ),
1457+ tf .data .Dataset .from_tensor_slices (y_val ))
1458+ )
1459+ if shuffle :
1460+ validation_data = validation_data .shuffle (buffer_size = 1024 ).batch (batch_size )
1461+ else :
1462+ validation_data = validation_data .batch (batch_size )
1463+ return validation_data
14211464
14221465
14231466 def _get_legal_params (self , params ):
@@ -1467,13 +1510,17 @@ def _initialize_weights(self, shape_X):
14671510
14681511
14691512 def _unpack_data (self , data ):
1470- data_X = data [0 ]
1471- data_y = data [1 ]
1472- Xs = data_X [:- 1 ]
1473- Xt = data_X [- 1 ]
1474- ys = data_y [:- 1 ]
1475- yt = data_y [- 1 ]
1476- return Xs , Xt , ys , ys
1513+ data_src = data [0 ]
1514+ data_tgt = data [1 ]
1515+ Xs = data_src [0 ]
1516+ ys = data_src [1 ]
1517+ if isinstance (data_tgt , tuple ):
1518+ Xt = data_tgt [0 ]
1519+ yt = data_tgt [1 ]
1520+ return Xs , Xt , ys , yt
1521+ else :
1522+ Xt = data_tgt
1523+ return Xs , Xt , ys , None
14771524
14781525
14791526 def _get_disc_metrics (self , ys_disc , yt_disc ):
0 commit comments