@@ -367,7 +367,7 @@ def _save_validation_data(self, Xs, Xt):
         else:
             self.Xs_ = Xs
             self.Xt_ = Xt
-            self.src_index_ = np.arange(len(Xs))
+            self.src_index_ = None


     def _get_target_data(self, X, y):
@@ -857,7 +857,7 @@ def __init__(self,
         self._self_setattr_tracking = True


-    def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
+    def fit(self, X, y=None, Xt=None, yt=None, domains=None, **fit_params):
         """
         Fit Model. Note that ``fit`` does not reset
         the model but extends the training.
@@ -867,7 +867,7 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
         X : array or Tensor
             Source input data.

-        y : array or Tensor
+        y : array or Tensor (default=None)
             Source output data.

         Xt : array (default=None)
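
With this change ``fit`` also accepts a ``tf.data.Dataset`` as first argument, in which case ``y`` may be omitted. A minimal usage sketch (``model`` stands for an instance of this adaptation class, and the shapes are illustrative assumptions):

import numpy as np
import tensorflow as tf

Xs, ys = np.random.randn(100, 2), np.random.randn(100, 1)
Xt = np.random.randn(80, 2)

model.fit(Xs, ys, Xt=Xt, epochs=5)     # array inputs, as before
ds = tf.data.Dataset.from_tensor_slices((Xs, ys))
model.fit(ds, Xt=Xt, epochs=5)         # Dataset input, y omitted
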
@@ -896,64 +896,120 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
         if not hasattr(self, "_is_fitted"):
             self._is_fitted = True
             self._initialize_networks()
-            self._initialize_weights(X.shape[1:])
+            if isinstance(X, tf.data.Dataset):
+                first_elem = next(iter(X))
+                if (not isinstance(first_elem, tuple) or
+                        len(first_elem) != 2):
+                    raise ValueError("When the first argument is a Dataset, "
+                                     "it should return (x, y) tuples.")
+                else:
+                    shape = first_elem[0].shape
+            else:
+                shape = X.shape[1:]
+            self._initialize_weights(shape)
+
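
# Shape-inference sketch for the Dataset branch above (values illustrative):
import numpy as np
import tensorflow as tf
ds = tf.data.Dataset.from_tensor_slices((np.zeros((10, 2)), np.zeros((10, 1))))
x0, y0 = next(iter(ds))        # first (x, y) pair of the dataset
print(x0.shape)                # (2,) -> passed to _initialize_weights(shape)
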
+        # 2. Get Fit params
+        fit_params = self._filter_params(super().fit, fit_params)
+
+        verbose = fit_params.get("verbose", 1)
+        epochs = fit_params.get("epochs", 1)
+        batch_size = fit_params.pop("batch_size", 32)
+        shuffle = fit_params.pop("shuffle", True)
+        validation_data = fit_params.pop("validation_data", None)
+        validation_split = fit_params.pop("validation_split", 0.)
+        validation_batch_size = fit_params.pop("validation_batch_size", batch_size)
+
+        # 3. Prepare datasets

-        # 2. Prepare dataset
+        ### 3.1 Source
+        if not isinstance(X, tf.data.Dataset):
+            check_arrays(X, y)
+            if len(y.shape) <= 1:
+                y = y.reshape(-1, 1)
+
+            # Single source
+            if domains is None:
+                self.n_sources_ = 1
+
+                dataset_Xs = tf.data.Dataset.from_tensor_slices(X)
+                dataset_ys = tf.data.Dataset.from_tensor_slices(y)
+
+            # Multisource
+            else:
+                domains = self._check_domains(domains)
+                self.n_sources_ = int(np.max(domains) + 1)
+
+                sizes = [np.sum(domains == dom)
+                         for dom in range(self.n_sources_)]
+
+                max_size = np.max(sizes)
+                repeats = np.ceil(max_size / np.array(sizes)).astype(int)
+
+                dataset_Xs = tf.data.Dataset.zip(tuple(
+                    tf.data.Dataset.from_tensor_slices(X[domains == dom]).repeat(repeats[dom])
+                    for dom in range(self.n_sources_))
+                )
+
+                dataset_ys = tf.data.Dataset.zip(tuple(
+                    tf.data.Dataset.from_tensor_slices(y[domains == dom]).repeat(repeats[dom])
+                    for dom in range(self.n_sources_))
+                )
+
+            dataset_src = tf.data.Dataset.zip((dataset_Xs, dataset_ys))
+
+        else:
+            dataset_src = X
+
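
# Balancing sketch for the multisource branch above: smaller domains are
# repeated so the zipped per-domain datasets stay in lockstep (sizes illustrative):
import numpy as np
sizes = np.array([100, 40])                            # samples per source domain
repeats = np.ceil(np.max(sizes) / sizes).astype(int)   # -> [1, 3]
# domain 1 (40 samples) is repeated 3 times, so zipping never exhausts it
# before the largest domain (100 samples) completes a pass
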
+        ### 3.2 Target
         Xt, yt = self._get_target_data(Xt, yt)
+        if not isinstance(Xt, tf.data.Dataset):
+            if yt is None:
+                check_array(Xt, ensure_2d=True, allow_nd=True)
+                dataset_tgt = tf.data.Dataset.from_tensor_slices(Xt)

-        check_arrays(X, y)
-        if len(y.shape) <= 1:
-            y = y.reshape(-1, 1)
+            else:
+                check_arrays(Xt, yt)
+
+                if len(yt.shape) <= 1:
+                    yt = yt.reshape(-1, 1)
+
+                dataset_Xt = tf.data.Dataset.from_tensor_slices(Xt)
+                dataset_yt = tf.data.Dataset.from_tensor_slices(yt)
+                dataset_tgt = tf.data.Dataset.zip((dataset_Xt, dataset_yt))

-        if yt is None:
-            yt = y
-            check_array(Xt, ensure_2d=True, allow_nd=True)
         else:
-            check_arrays(Xt, yt)
-
-            if len(yt.shape) <= 1:
-                yt = yt.reshape(-1, 1)
+            dataset_tgt = Xt

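
# Target-branch sketch: an unlabeled target yields a plain Xt dataset, a
# labeled one a zipped (Xt, yt) dataset (arrays illustrative):
import numpy as np
import tensorflow as tf
Xt, yt = np.zeros((80, 2)), np.zeros((80, 1))
tgt_unlabeled = tf.data.Dataset.from_tensor_slices(Xt)
tgt_labeled = tf.data.Dataset.zip(
    (tf.data.Dataset.from_tensor_slices(Xt),
     tf.data.Dataset.from_tensor_slices(yt)))
# _unpack_data later distinguishes the two cases by checking for a tuple
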
         self._save_validation_data(X, Xt)

-        domains = fit_params.pop("domains", None)
-
-        if domains is None:
-            domains = np.zeros(len(X))
-
-        domains = self._check_domains(domains)
-
-        self.n_sources_ = int(np.max(domains) + 1)
-
-        sizes = np.array(
-            [np.sum(domains == dom) for dom in range(self.n_sources_)] +
-            [len(Xt)])
-
-        max_size = np.max(sizes)
-        repeats = np.ceil(max_size / sizes)
-
-        dataset_X = tf.data.Dataset.zip(tuple(
-            tf.data.Dataset.from_tensor_slices(X[domains == dom]).repeat(repeats[dom])
-            for dom in range(self.n_sources_)) +
-            (tf.data.Dataset.from_tensor_slices(Xt).repeat(repeats[-1]),)
-        )
-
-        dataset_y = tf.data.Dataset.zip(tuple(
-            tf.data.Dataset.from_tensor_slices(y[domains == dom]).repeat(repeats[dom])
-            for dom in range(self.n_sources_)) +
-            (tf.data.Dataset.from_tensor_slices(yt).repeat(repeats[-1]),)
-        )
-
-
-        # 3. Get Fit params
-        fit_params = self._filter_params(super().fit, fit_params)
-
-        verbose = fit_params.get("verbose", 1)
-        epochs = fit_params.get("epochs", 1)
-        batch_size = fit_params.pop("batch_size", 32)
-        shuffle = fit_params.pop("shuffle", True)
+        # 4. Get validation data
+        validation_data = self._check_validation_data(validation_data,
+                                                      validation_batch_size,
+                                                      shuffle)
+
+        if validation_data is None and validation_split > 0.:
+            if shuffle:
+                dataset_src = dataset_src.shuffle(buffer_size=1024)
+            frac = int(len(dataset_src) * validation_split)
+            validation_data = dataset_src.take(frac)
+            dataset_src = dataset_src.skip(frac)
+            validation_data = validation_data.batch(batch_size)
+
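
# validation_split sketch (illustrative): a 10-element dataset with
# validation_split=0.2 yields a 2-element validation set via take/skip.
import tensorflow as tf
ds = tf.data.Dataset.range(10)
frac = int(len(ds) * 0.2)              # -> 2 (len needs known cardinality)
val, train = ds.take(frac), ds.skip(frac)
print(len(val), len(train))            # 2 8
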
+        # 5. Set datasets
+        try:
+            max_size = max(len(dataset_src), len(dataset_tgt))
+            repeat_src = int(np.ceil(max_size / len(dataset_src)))
+            repeat_tgt = int(np.ceil(max_size / len(dataset_tgt)))
+
+            dataset_src = dataset_src.repeat(repeat_src)
+            dataset_tgt = dataset_tgt.repeat(repeat_tgt)
+
+            self.total_steps_ = float(np.ceil(max_size / batch_size) * epochs)
+        except TypeError:
+            # len() is undefined for infinite or unknown-cardinality datasets
+            pass

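
# Source/target alignment sketch (illustrative sizes): with 100 source and
# 80 target samples, the target stream is repeated to cover one source pass.
import numpy as np
max_size = max(100, 80)                          # -> 100
repeat_tgt = int(np.ceil(max_size / 80))         # -> 2
total_steps = float(np.ceil(max_size / 32) * 1)  # ceil(100/32) * epochs = 4.0
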
-        # 4. Pretraining
+        # 6. Pretraining
         if not hasattr(self, "pretrain_"):
             if not hasattr(self, "pretrain"):
                 self.pretrain_ = False
@@ -980,36 +1036,39 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
             pre_epochs = prefit_params.pop("epochs", epochs)
             pre_batch_size = prefit_params.pop("batch_size", batch_size)
             pre_shuffle = prefit_params.pop("shuffle", shuffle)
+            prefit_params.pop("validation_data", None)
+            prefit_params.pop("validation_split", None)
+            prefit_params.pop("validation_batch_size", None)

             if pre_shuffle:
-                dataset = tf.data.Dataset.zip((dataset_X, dataset_y)).shuffle(buffer_size=1024).batch(pre_batch_size)
+                dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt)).shuffle(buffer_size=1024).batch(pre_batch_size)
             else:
-                dataset = tf.data.Dataset.zip((dataset_X, dataset_y)).batch(pre_batch_size)
+                dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt)).batch(pre_batch_size)

-            hist = super().fit(dataset, epochs=pre_epochs, verbose=pre_verbose, **prefit_params)
+            hist = super().fit(dataset, validation_data=validation_data,
+                               epochs=pre_epochs, verbose=pre_verbose, **prefit_params)

             for k, v in hist.history.items():
                 self.pretrain_history_[k] = self.pretrain_history_.get(k, []) + v

             self._initialize_pretain_networks()
-
-        # 5. Training
+
+        # 7. Training
         if (not self._is_compiled) or (self.pretrain_):
             self.compile()

         if not hasattr(self, "history_"):
             self.history_ = {}

         if shuffle:
-            dataset = tf.data.Dataset.zip((dataset_X, dataset_y)).shuffle(buffer_size=1024).batch(batch_size)
+            dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt)).shuffle(buffer_size=1024).batch(batch_size)
         else:
-            dataset = tf.data.Dataset.zip((dataset_X, dataset_y)).batch(batch_size)
-
+            dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt)).batch(batch_size)
+
         self.pretrain_ = False
         self.steps_ = tf.Variable(0.)
-        self.total_steps_ = float(np.ceil(max_size / batch_size) * epochs)

-        hist = super().fit(dataset, **fit_params)
+        hist = super().fit(dataset, validation_data=validation_data, **fit_params)

         for k, v in hist.history.items():
             self.history_[k] = self.history_.get(k, []) + v
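
For reference, a sketch of the batch structure the zipped dataset produces, which ``train_step`` and ``_unpack_data`` below consume (shapes illustrative):

import numpy as np
import tensorflow as tf

dataset_src = tf.data.Dataset.from_tensor_slices(
    (np.zeros((8, 2)), np.zeros((8, 1))))                           # (Xs, ys) pairs
dataset_tgt = tf.data.Dataset.from_tensor_slices(np.zeros((8, 2)))  # unlabeled Xt
dataset = tf.data.Dataset.zip((dataset_src, dataset_tgt)).batch(4)
(Xs, ys), Xt = next(iter(dataset))   # with a labeled target: ((Xs, ys), (Xt, yt))
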
@@ -1199,10 +1258,6 @@ def train_step(self, data):
         # Unpack the data.
         Xs, Xt, ys, yt = self._unpack_data(data)

-        # Single source
-        Xs = Xs[0]
-        ys = ys[0]
-
         # Run forward pass.
         with tf.GradientTape() as tape:
             y_pred = self(Xs, training=True)
@@ -1390,6 +1445,22 @@ def score_estimator(self, X, y, sample_weight=None):
         if isinstance(score, (tuple, list)):
             score = score[0]
         return score
+
+
+    def _check_validation_data(self, validation_data, batch_size, shuffle):
+        if isinstance(validation_data, tuple):
+            X_val = validation_data[0]
+            y_val = validation_data[1]
+
+            validation_data = tf.data.Dataset.zip(
+                (tf.data.Dataset.from_tensor_slices(X_val),
+                 tf.data.Dataset.from_tensor_slices(y_val))
+            )
+            if shuffle:
+                validation_data = validation_data.shuffle(buffer_size=1024).batch(batch_size)
+            else:
+                validation_data = validation_data.batch(batch_size)
+        return validation_data
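
# Usage sketch for _check_validation_data (``model`` stands for an instance
# of this class; arrays illustrative): a (X_val, y_val) tuple is wrapped
# and batched, while an existing tf.data.Dataset or None would pass through.
import numpy as np
X_val, y_val = np.zeros((20, 2)), np.zeros((20, 1))
val_ds = model._check_validation_data((X_val, y_val),
                                      batch_size=32, shuffle=False)
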


     def _get_legal_params(self, params):
@@ -1439,13 +1510,17 @@ def _initialize_weights(self, shape_X):


     def _unpack_data(self, data):
-        data_X = data[0]
-        data_y = data[1]
-        Xs = data_X[:-1]
-        Xt = data_X[-1]
-        ys = data_y[:-1]
-        yt = data_y[-1]
-        return Xs, Xt, ys, ys
+        data_src = data[0]
+        data_tgt = data[1]
+        Xs = data_src[0]
+        ys = data_src[1]
+        if isinstance(data_tgt, tuple):
+            Xt = data_tgt[0]
+            yt = data_tgt[1]
+            return Xs, Xt, ys, yt
+        else:
+            Xt = data_tgt
+            return Xs, Xt, ys, None


     def _get_disc_metrics(self, ys_disc, yt_disc):