Skip to content
This repository was archived by the owner on Aug 25, 2024. It is now read-only.

Commit 4d46339

Browse files
committed
model: tensorflow: Back out tensorflow 2.X support temporarily
Related: #374 Signed-off-by: John Andersen <[email protected]>
1 parent 323fd6d commit 4d46339

File tree

6 files changed

+19
-23
lines changed

6 files changed

+19
-23
lines changed

CHANGELOG.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
4040
to setuptools `setup` function within a `setup.py` file.
4141
### Changed
4242
- All instances of `src_url` changed to `key`.
43-
- Moved from tensorflow 1 to tensorflow 2.
4443
- `readonly` parameter in source config is now changed to `readwrite`.
4544
- `predict` parameter of all model config classes has been changed from `str` to `Feature`.
4645
- Defining features on the command line no longer requires that defined features

model/tensorflow/dffml_model_tensorflow/dnnc.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@
1010
from typing import List, Dict, Any, AsyncIterator, Tuple, Optional, Type
1111

1212
import numpy as np
13-
14-
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
15-
import tensorflow as tf
13+
import tensorflow
1614

1715
from dffml.repo import Repo
1816
from dffml.feature import Feature, Features
@@ -67,7 +65,7 @@ def _feature_feature_column(self, feature: Feature):
6765
or dtype is float
6866
or issubclass(dtype, float)
6967
):
70-
return tf.feature_column.numeric_column(
68+
return tensorflow.feature_column.numeric_column(
7169
feature.NAME, shape=feature.length()
7270
)
7371
self.logger.warning(
@@ -114,7 +112,7 @@ async def predict_input_fn(self, repos: AsyncIterator[Repo], **kwargs):
114112
self.logger.info("------ Repo Data ------")
115113
self.logger.info("x_cols: %d", len(list(x_cols.values())[0]))
116114
self.logger.info("-----------------------")
117-
input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
115+
input_fn = tensorflow.estimator.inputs.numpy_input_fn(
118116
x_cols, shuffle=False, num_epochs=1, **kwargs
119117
)
120118
return input_fn, ret_repos
@@ -208,7 +206,7 @@ def model(self):
208206
len(self.classifications),
209207
self.classifications,
210208
)
211-
self._model = tf.estimator.DNNClassifier(
209+
self._model = tensorflow.estimator.DNNClassifier(
212210
feature_columns=list(self.feature_columns.values()),
213211
hidden_units=self.parent.config.hidden,
214212
n_classes=len(self.parent.config.classifications),
@@ -249,7 +247,7 @@ async def training_input_fn(
249247
self.logger.info("x_cols: %d", len(list(x_cols.values())[0]))
250248
self.logger.info("y_cols: %d", len(y_cols))
251249
self.logger.info("-----------------------")
252-
input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
250+
input_fn = tensorflow.estimator.inputs.numpy_input_fn(
253251
x_cols,
254252
y_cols,
255253
batch_size=self.parent.config.batchsize,
@@ -289,7 +287,7 @@ async def accuracy_input_fn(
289287
self.logger.info("x_cols: %d", len(list(x_cols.values())[0]))
290288
self.logger.info("y_cols: %d", len(y_cols))
291289
self.logger.info("-----------------------")
292-
input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
290+
input_fn = tensorflow.estimator.inputs.numpy_input_fn(
293291
x_cols,
294292
y_cols,
295293
batch_size=self.parent.config.batchsize,

model/tensorflow/dffml_model_tensorflow/dnnr.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,19 @@
66
from typing import List, Dict, Any, AsyncIterator
77

88
import numpy as np
9-
10-
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
11-
import tensorflow as tf
9+
import tensorflow
1210

1311
from dffml.repo import Repo
14-
from dffml.util.cli.arg import Arg
12+
from dffml.source.source import Sources
1513
from dffml.model.model import Model
1614
from dffml.accuracy import Accuracy
17-
from dffml.source.source import Sources
1815
from dffml.util.entrypoint import entrypoint
19-
from dffml.util.cli.parser import list_action
2016
from dffml.base import BaseConfig, config, field
17+
from dffml.util.cli.arg import Arg
2118
from dffml.feature.feature import Feature, Features
19+
from dffml.util.cli.parser import list_action
2220

23-
from .dnnc import TensorflowModelContext
21+
from dffml_model_tensorflow.dnnc import TensorflowModelContext
2422

2523

2624
@config
@@ -65,8 +63,9 @@ def model(self):
6563
if self._model is not None:
6664
return self._model
6765
self.logger.debug("Loading model ")
68-
_head = tf.estimator.RegressionHead()
69-
self._model = tf.estimator.DNNEstimator(
66+
67+
_head = tensorflow.contrib.estimator.regression_head()
68+
self._model = tensorflow.estimator.DNNEstimator(
7069
head=_head,
7170
feature_columns=list(self.feature_columns.values()),
7271
hidden_units=self.parent.config.hidden,
@@ -103,7 +102,7 @@ async def training_input_fn(
103102
self.logger.info("x_cols: %d", len(list(x_cols.values())[0]))
104103
self.logger.info("y_cols: %d", len(y_cols))
105104
self.logger.info("-----------------------")
106-
input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
105+
input_fn = tensorflow.estimator.inputs.numpy_input_fn(
107106
x_cols,
108107
y_cols,
109108
batch_size=batch_size,
@@ -139,7 +138,7 @@ async def evaluate_input_fn(
139138
self.logger.info("x_cols: %d", len(list(x_cols.values())[0]))
140139
self.logger.info("y_cols: %d", len(y_cols))
141140
self.logger.info("-----------------------")
142-
input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
141+
input_fn = tensorflow.estimator.inputs.numpy_input_fn(
143142
x_cols,
144143
y_cols,
145144
batch_size=batch_size,

model/tensorflow/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
with open(os.path.join(self_path, "README.md"), "r", encoding="utf-8") as f:
1919
readme = f.read()
2020

21-
INSTALL_REQUIRES = ["tensorflow>=2.1.0rc2"] + (
21+
INSTALL_REQUIRES = ["tensorflow==1.14.0"] + (
2222
["dffml>=0.3.1"]
2323
if not any(
2424
list(

model/tensorflow/tests/test_dnnc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def setUpClass(cls):
5555
directory=cls.model_dir.name,
5656
steps=1000,
5757
epochs=30,
58-
hidden=[300, 200, 80, 10],
58+
hidden=[10, 20, 10],
5959
predict=DefFeature("string", str, 1),
6060
classifications=["a", "not a"],
6161
clstype=str,

model/tensorflow/tests/test_dnnr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def setUpClass(cls):
5151
directory=cls.model_dir.name,
5252
steps=1000,
5353
epochs=30,
54-
hidden=[200, 100, 80, 10],
54+
hidden=[10, 20, 10],
5555
predict=DefFeature("TARGET", float, 1),
5656
features=cls.features,
5757
)

0 commit comments

Comments
 (0)