@@ -336,6 +336,8 @@ def fit(self, X=None, y=None, batch_size=128, epochs=1, verbose=1, callbacks=Non
336336 max_queue_size = 10 , workers = 1 , use_multiprocessing = False ):
337337 logger .info (f'X.Shape={ np .shape (X )} , y.Shape={ np .shape (y )} , batch_size={ batch_size } , config={ self .config } ' )
338338 logger .info (f'metrics:{ self .config .metrics } ' )
339+ if np .ndim (X ) != 2 :
340+ raise ValueError ("Input train data should be 2d ." )
339341 X_shape = np .shape (X )
340342 if X_shape [1 ] < 1 :
341343 raise ValueError ("Input train data should has 1 feature at least." )
@@ -353,7 +355,7 @@ def fit(self, X=None, y=None, batch_size=128, epochs=1, verbose=1, callbacks=Non
353355 model = DeepModel (self .task , self .num_classes , self .config ,
354356 self .preprocessor .categorical_columns ,
355357 self .preprocessor .continuous_columns ,
356- self .preprocessor .var_len_categorical_columns )
358+ var_categorical_len_columns = self .preprocessor .var_len_categorical_columns )
357359 history = model .fit (X , y , batch_size = batch_size , epochs = epochs , verbose = verbose , shuffle = shuffle ,
358360 validation_split = validation_split , validation_data = validation_data ,
359361 validation_steps = validation_steps , validation_freq = validation_freq ,
@@ -706,7 +708,8 @@ def load_deepmodel(self, filepath):
706708 if os .path .exists (filepath ):
707709 print (f'Load model from disk:{ filepath } .' )
708710 dm = DeepModel (self .task , self .num_classes , self .config ,
709- self .preprocessor .categorical_columns , self .preprocessor .continuous_columns , filepath )
711+ self .preprocessor .categorical_columns ,
712+ self .preprocessor .continuous_columns , model_file = filepath )
710713 return dm
711714 else :
712715 raise ValueError (f'Invalid model filename:{ filepath } .' )
0 commit comments