Skip to content

Commit 948cdba

Browse files
committed
Merge branch 'master' of github.com:DataCanvasIO/DeepTables into dev_wuhf
2 parents 1ed27c8 + 37fb8e8 commit 948cdba

File tree

5 files changed

+22
-4
lines changed

5 files changed

+22
-4
lines changed

deeptables/models/deeptable.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from .deepmodel import DeepModel
2323
from .preprocessor import DefaultPreprocessor
2424
from ..utils import dt_logging, consts
25+
from ..utils.tf_version import tf_less_than
2526

2627
logger = dt_logging.get_logger()
2728

@@ -648,11 +649,11 @@ def __inject_callbacks(self, callbacks):
648649
# callbacks.append(mcp)
649650
# print(f'Injected a callback [ModelCheckpoint].\nfilepath:{mcp.filepath}\nmonitor:{mcp.monitor}')
650651
if es is None:
651-
es = EarlyStopping(monitor=self.monitor,
652+
es = EarlyStopping(monitor=self.monitor if tf_less_than('2.2') else self.monitor.lower(),
652653
restore_best_weights=True,
653654
patience=self.config.earlystopping_patience,
654655
verbose=1,
655-
#min_delta=0.0001,
656+
# min_delta=0.0001,
656657
mode=mode,
657658
baseline=None,
658659
)

deeptables/utils/tf_version.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# -*- coding:utf-8 -*-
2+
__author__ = 'yangjian'
3+
"""
4+
5+
"""
6+
import tensorflow as tf
7+
from packaging.version import parse
8+
9+
10+
def tf_less_than(version):
11+
return parse(tf.__version__) < parse(version)
12+
13+
14+
def tf_greater_than(version):
15+
return parse(tf.__version__) > parse(version)

docs/source/examples.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Exapmles
1+
# Examples
22

33
## Binary Classification
44

@@ -78,4 +78,4 @@ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_
7878
model, history = dt.fit(X_train, y_train, epochs=100)
7979

8080
score = dt.evaluate(X_test, y_test)
81-
```
81+
```

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ scikit-optimize==0.7.1
99
tables==3.6.1
1010
category_encoders==2.1.0
1111
hypernets>=0.1.2
12+
h5py==2.10.0
1213
eli5

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
'category_encoders==2.1.0',
2020
'tables==3.6.1',
2121
'hypernets>=0.1.2',
22+
'h5py==2.10.0',
2223
'eli5',
2324
]
2425

0 commit comments

Comments
 (0)