Skip to content

Commit 8368dd9

Browse files
authored
Upgrade tensorflow to 2.x (#41)
1 parent 1ddbeb3 commit 8368dd9

File tree

5 files changed

+19
-12
lines changed

5 files changed

+19
-12
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
language: python
22
python:
3-
- 3.6
3+
- 3.8
44
services:
55
- docker
66
install:

Dockerfile

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@
1414
# limitations under the License.
1515
#
1616

17-
FROM quay.io/codait/max-base:v1.4.0
17+
FROM quay.io/codait/max-base:v1.5.1
1818

1919
COPY requirements.txt .
2020

2121
RUN pip install -r requirements.txt
22-
22+
2323
COPY . .
24-
24+
2525
EXPOSE 5000
2626

2727
CMD python app.py

api/predict.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from core.model import ModelWrapper
1818
from maxfw.core import MAX_API, PredictAPI
19-
from flask_restplus import fields
19+
from flask_restx import fields
2020
from werkzeug.datastructures import FileStorage
2121
from config import DEFAULT_MODEL, MODELS
2222

core/model.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from config import DEFAULT_MODEL_PATH, MODELS, MODEL_META_DATA as model_meta
2121
from keras.models import load_model
2222
import numpy as np
23-
from sklearn.externals import joblib
23+
import joblib
2424

2525
logging.basicConfig()
2626
logger = logging.getLogger()
@@ -34,12 +34,19 @@ def load_array(input_data):
3434
class SingleModelWrapper(object):
3535

3636
def __init__(self, model, path):
37+
# The code was originally written for TF1 and doesn't work with eager mode.
38+
tf.compat.v1.disable_eager_execution()
39+
self.session = tf.compat.v1.Session()
40+
3741
self.model_name = model
3842

3943
# load model
4044
model_path = '{}/{}_model'.format(path, model)
4145
logger.info(model_path)
42-
self.graph = tf.get_default_graph()
46+
self.graph = tf.compat.v1.get_default_graph()
47+
# See https://github.com/tensorflow/tensorflow/issues/28287#issuecomment-495005162
48+
# We have to do this because we load 3 models in the process.
49+
tf.compat.v1.keras.backend.set_session(self.session)
4350
self.model = load_model(model_path)
4451

4552
# load scaler
@@ -96,6 +103,7 @@ def _rescale_preds(self, preds):
96103
def predict(self, x):
97104
reshaped_x = self._reshape_data(x)
98105
with self.graph.as_default():
106+
tf.compat.v1.keras.backend.set_session(self.session)
99107
preds = self.model.predict(reshaped_x)
100108
rescaled_preds = self._rescale_preds(preds)
101109
return rescaled_preds

requirements.txt

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
numpy==1.16.2
2-
tensorflow==1.15.2
3-
keras==2.2.4
4-
scikit-learn==0.22.1
5-
h5py==2.9.0
1+
tensorflow==2.6.0
2+
keras==2.6.0
3+
scikit-learn==0.22.2
4+
# numpy and h5py not specified here because tensorflow has specifies a major.minor version range for them

0 commit comments

Comments
 (0)