@@ -897,41 +897,64 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
897897 self ._is_fitted = True
898898 self ._initialize_networks ()
899899 self ._initialize_weights (X .shape [1 :])
900+
901+ # 2. Get Fit params
902+ fit_params = self ._filter_params (super ().fit , fit_params )
900903
901- # 2. Prepare dataset
904+ verbose = fit_params .get ("verbose" , 1 )
905+ epochs = fit_params .get ("epochs" , 1 )
906+ batch_size = fit_params .pop ("batch_size" , 32 )
907+ shuffle = fit_params .pop ("shuffle" , True )
908+ validation_data = fit_params .pop ("validation_data" , None )
909+ validation_split = fit_params .pop ("validation_split" , 0. )
910+ validation_batch_size = fit_params .pop ("validation_batch_size" , batch_size )
911+
912+ # 3. Prepare dataset
913+
914+ ### 3.1 Source
915+ if not isinstance (X , tf .data .Dataset ):
916+ check_arrays (X , y )
917+ if len (y .shape ) <= 1 :
918+ y = y .reshape (- 1 , 1 )
919+
920+ ### 3.2 Target
902921 Xt , yt = self ._get_target_data (Xt , yt )
903-
904- check_arrays (X , y )
905- if len (y .shape ) <= 1 :
906- y = y .reshape (- 1 , 1 )
922+ if not isinstance (Xt , tf .data .Dataset ):
923+ if yt is None :
924+ yt = y
925+ check_array (Xt , ensure_2d = True , allow_nd = True )
926+ else :
927+ check_arrays (Xt , yt )
907928
908- if yt is None :
909- yt = y
910- check_array (Xt , ensure_2d = True , allow_nd = True )
911- else :
912- check_arrays (Xt , yt )
913-
914- if len (yt .shape ) <= 1 :
915- yt = yt .reshape (- 1 , 1 )
929+ if len (yt .shape ) <= 1 :
930+ yt = yt .reshape (- 1 , 1 )
916931
917932 self ._save_validation_data (X , Xt )
918933
934+ ### 3.3 Domains
919935 domains = fit_params .pop ("domains" , None )
920-
936+
921937 if domains is None :
922938 domains = np .zeros (len (X ))
923-
939+
924940 domains = self ._check_domains (domains )
925941
926942 self .n_sources_ = int (np .max (domains )+ 1 )
927-
943+
928944 sizes = np .array (
929945 [np .sum (domains == dom ) for dom in range (self .n_sources_ )]+
930946 [len (Xt )])
931-
947+
932948 max_size = np .max (sizes )
933949 repeats = np .ceil (max_size / sizes )
934950
951+ # Split if validation_split
952+ # if validation_data is None and validation_split>0.:
953+ # frac = int(len(dataset)*validation_split)
954+ # validation_data = dataset.take(frac)
955+ # dataset = dataset.skip(frac)
956+
957+
935958 dataset_X = tf .data .Dataset .zip (tuple (
936959 tf .data .Dataset .from_tensor_slices (X [domains == dom ]).repeat (repeats [dom ])
937960 for dom in range (self .n_sources_ ))+
@@ -944,15 +967,6 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
944967 (tf .data .Dataset .from_tensor_slices (yt ).repeat (repeats [- 1 ]),)
945968 )
946969
947-
948- # 3. Get Fit params
949- fit_params = self ._filter_params (super ().fit , fit_params )
950-
951- verbose = fit_params .get ("verbose" , 1 )
952- epochs = fit_params .get ("epochs" , 1 )
953- batch_size = fit_params .pop ("batch_size" , 32 )
954- shuffle = fit_params .pop ("shuffle" , True )
955-
956970 # 4. Pretraining
957971 if not hasattr (self , "pretrain_" ):
958972 if not hasattr (self , "pretrain" ):
@@ -993,7 +1007,21 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
9931007
9941008 self ._initialize_pretain_networks ()
9951009
996- # 5. Training
1010+ # 5. Define validation Set
1011+ if isinstance (validation_data , tuple ):
1012+ X_val = validation_data [0 ]
1013+ y_val = validation_data [1 ]
1014+
1015+ validation_data = tf .data .Dataset .zip (
1016+ (tf .data .Dataset .from_tensor_slices (X_val ),
1017+ tf .data .Dataset .from_tensor_slices (y_val ))
1018+ )
1019+ if shuffle :
1020+ validation_data = validation_data .shuffle (buffer_size = 1024 ).batch (batch_size )
1021+ else :
1022+ validation_data = validation_data .batch (batch_size )
1023+
1024+ # 6. Training
9971025 if (not self ._is_compiled ) or (self .pretrain_ ):
9981026 self .compile ()
9991027
@@ -1004,12 +1032,12 @@ def fit(self, X, y, Xt=None, yt=None, domains=None, **fit_params):
10041032 dataset = tf .data .Dataset .zip ((dataset_X , dataset_y )).shuffle (buffer_size = 1024 ).batch (batch_size )
10051033 else :
10061034 dataset = tf .data .Dataset .zip ((dataset_X , dataset_y )).batch (batch_size )
1007-
1035+
10081036 self .pretrain_ = False
10091037 self .steps_ = tf .Variable (0. )
10101038 self .total_steps_ = float (np .ceil (max_size / batch_size )* epochs )
10111039
1012- hist = super ().fit (dataset , ** fit_params )
1040+ hist = super ().fit (dataset , validation_data = validation_data , ** fit_params )
10131041
10141042 for k , v in hist .history .items ():
10151043 self .history_ [k ] = self .history_ .get (k , []) + v
0 commit comments