Skip to content

Commit f51fb01

Browse files
committed
Fix model load .
1 parent 948cdba commit f51fb01

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

deeptables/models/deepmodel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ def __init__(self,
2828
config,
2929
categorical_columns,
3030
continuous_columns,
31-
var_categorical_len_columns=None, # Compatible persisted model
32-
model_file=None):
31+
model_file=None,
32+
var_categorical_len_columns=None, ):
3333

3434
# set gpu usage strategy before build model
3535
if config.gpu_usage_strategy == consts.GPU_USAGE_STRATEGY_GROWTH:

deeptables/models/deeptable.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,8 @@ def fit(self, X=None, y=None, batch_size=128, epochs=1, verbose=1, callbacks=Non
336336
max_queue_size=10, workers=1, use_multiprocessing=False):
337337
logger.info(f'X.Shape={np.shape(X)}, y.Shape={np.shape(y)}, batch_size={batch_size}, config={self.config}')
338338
logger.info(f'metrics:{self.config.metrics}')
339+
if np.ndim(X) != 2:
340+
raise ValueError("Input train data should be 2d .")
339341
X_shape = np.shape(X)
340342
if X_shape[1] < 1:
341343
raise ValueError("Input train data should has 1 feature at least.")
@@ -353,7 +355,7 @@ def fit(self, X=None, y=None, batch_size=128, epochs=1, verbose=1, callbacks=Non
353355
model = DeepModel(self.task, self.num_classes, self.config,
354356
self.preprocessor.categorical_columns,
355357
self.preprocessor.continuous_columns,
356-
self.preprocessor.var_len_categorical_columns)
358+
var_categorical_len_columns=self.preprocessor.var_len_categorical_columns)
357359
history = model.fit(X, y, batch_size=batch_size, epochs=epochs, verbose=verbose, shuffle=shuffle,
358360
validation_split=validation_split, validation_data=validation_data,
359361
validation_steps=validation_steps, validation_freq=validation_freq,
@@ -706,7 +708,8 @@ def load_deepmodel(self, filepath):
706708
if os.path.exists(filepath):
707709
print(f'Load model from disk:{filepath}.')
708710
dm = DeepModel(self.task, self.num_classes, self.config,
709-
self.preprocessor.categorical_columns, self.preprocessor.continuous_columns, filepath)
711+
self.preprocessor.categorical_columns,
712+
self.preprocessor.continuous_columns, model_file=filepath)
710713
return dm
711714
else:
712715
raise ValueError(f'Invalid model filename:{filepath}.')

0 commit comments

Comments
 (0)