Skip to content
This repository was archived by the owner on Aug 25, 2024. It is now read-only.

Commit c637850

Browse files
authored
model: tensorflow: TF 1.X to 2.X
1 parent 7dd0fa8 commit c637850

File tree

6 files changed

+30
-27
lines changed

6 files changed

+30
-27
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
66

77
## [Unreleased]
88
### Added
9+
- Moved from TensorFlow 1 to TensorFlow 2.
910
- IDX Sources to read binary data files and train models on MNIST Dataset
1011
- scikit models
1112
- Clusterers

model/tensorflow/dffml_model_tensorflow/dnnc.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,19 +10,21 @@
1010
from typing import List, Dict, Any, AsyncIterator, Tuple, Optional, Type
1111

1212
import numpy as np
13-
import tensorflow
13+
14+
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
15+
import tensorflow as tf
1416

1517
from dffml.repo import Repo
16-
from dffml.feature import Feature, Features
17-
from dffml.source.source import Sources
18-
from dffml.model.model import ModelConfig, ModelContext, Model, ModelNotTrained
19-
from dffml.accuracy import Accuracy
20-
from dffml.util.entrypoint import entrypoint
2118
from dffml.base import BaseConfig
2219
from dffml.util.cli.arg import Arg
23-
from dffml.feature.feature import Feature, Features
24-
from dffml.util.cli.parser import list_action
20+
from dffml.accuracy import Accuracy
2521
from dffml.base import config, field
22+
from dffml.source.source import Sources
23+
from dffml.feature import Feature, Features
24+
from dffml.util.entrypoint import entrypoint
25+
from dffml.util.cli.parser import list_action
26+
from dffml.feature.feature import Feature, Features
27+
from dffml.model.model import ModelConfig, ModelContext, Model, ModelNotTrained
2628

2729

2830
class TensorflowModelContext(ModelContext):
@@ -65,7 +67,7 @@ def _feature_feature_column(self, feature: Feature):
6567
or dtype is float
6668
or issubclass(dtype, float)
6769
):
68-
return tensorflow.feature_column.numeric_column(
70+
return tf.feature_column.numeric_column(
6971
feature.NAME, shape=feature.length()
7072
)
7173
self.logger.warning(
@@ -112,7 +114,7 @@ async def predict_input_fn(self, repos: AsyncIterator[Repo], **kwargs):
112114
self.logger.info("------ Repo Data ------")
113115
self.logger.info("x_cols: %d", len(list(x_cols.values())[0]))
114116
self.logger.info("-----------------------")
115-
input_fn = tensorflow.estimator.inputs.numpy_input_fn(
117+
input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
116118
x_cols, shuffle=False, num_epochs=1, **kwargs
117119
)
118120
return input_fn, ret_repos
@@ -206,7 +208,7 @@ def model(self):
206208
len(self.classifications),
207209
self.classifications,
208210
)
209-
self._model = tensorflow.estimator.DNNClassifier(
211+
self._model = tf.compat.v1.estimator.DNNClassifier(
210212
feature_columns=list(self.feature_columns.values()),
211213
hidden_units=self.parent.config.hidden,
212214
n_classes=len(self.parent.config.classifications),
@@ -247,7 +249,7 @@ async def training_input_fn(
247249
self.logger.info("x_cols: %d", len(list(x_cols.values())[0]))
248250
self.logger.info("y_cols: %d", len(y_cols))
249251
self.logger.info("-----------------------")
250-
input_fn = tensorflow.estimator.inputs.numpy_input_fn(
252+
input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
251253
x_cols,
252254
y_cols,
253255
batch_size=self.parent.config.batchsize,
@@ -287,7 +289,7 @@ async def accuracy_input_fn(
287289
self.logger.info("x_cols: %d", len(list(x_cols.values())[0]))
288290
self.logger.info("y_cols: %d", len(y_cols))
289291
self.logger.info("-----------------------")
290-
input_fn = tensorflow.estimator.inputs.numpy_input_fn(
292+
input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
291293
x_cols,
292294
y_cols,
293295
batch_size=self.parent.config.batchsize,

model/tensorflow/dffml_model_tensorflow/dnnr.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,21 @@
66
from typing import List, Dict, Any, AsyncIterator
77

88
import numpy as np
9-
import tensorflow
9+
10+
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
11+
import tensorflow as tf
1012

1113
from dffml.repo import Repo
12-
from dffml.source.source import Sources
14+
from dffml.util.cli.arg import Arg
1315
from dffml.model.model import Model
1416
from dffml.accuracy import Accuracy
17+
from dffml.source.source import Sources
1518
from dffml.util.entrypoint import entrypoint
19+
from dffml.util.cli.parser import list_action
1620
from dffml.base import BaseConfig, config, field
17-
from dffml.util.cli.arg import Arg
1821
from dffml.feature.feature import Feature, Features
19-
from dffml.util.cli.parser import list_action
2022

21-
from dffml_model_tensorflow.dnnc import TensorflowModelContext
23+
from .dnnc import TensorflowModelContext
2224

2325

2426
@config
@@ -64,9 +66,7 @@ def model(self):
6466
return self._model
6567
self.logger.debug("Loading model ")
6668

67-
_head = tensorflow.contrib.estimator.regression_head()
68-
self._model = tensorflow.estimator.DNNEstimator(
69-
head=_head,
69+
self._model = tf.compat.v1.estimator.DNNRegressor(
7070
feature_columns=list(self.feature_columns.values()),
7171
hidden_units=self.parent.config.hidden,
7272
model_dir=self.model_dir_path,
@@ -102,7 +102,7 @@ async def training_input_fn(
102102
self.logger.info("x_cols: %d", len(list(x_cols.values())[0]))
103103
self.logger.info("y_cols: %d", len(y_cols))
104104
self.logger.info("-----------------------")
105-
input_fn = tensorflow.estimator.inputs.numpy_input_fn(
105+
input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
106106
x_cols,
107107
y_cols,
108108
batch_size=batch_size,
@@ -138,7 +138,7 @@ async def evaluate_input_fn(
138138
self.logger.info("x_cols: %d", len(list(x_cols.values())[0]))
139139
self.logger.info("y_cols: %d", len(y_cols))
140140
self.logger.info("-----------------------")
141-
input_fn = tensorflow.estimator.inputs.numpy_input_fn(
141+
input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn(
142142
x_cols,
143143
y_cols,
144144
batch_size=batch_size,

model/tensorflow/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
with open(os.path.join(self_path, "README.md"), "r", encoding="utf-8") as f:
1919
readme = f.read()
2020

21-
INSTALL_REQUIRES = ["tensorflow==1.14.0"] + (
21+
INSTALL_REQUIRES = ["tensorflow>=2.0.0"] + (
2222
["dffml>=0.3.1"]
2323
if not any(
2424
list(

model/tensorflow/tests/test_dnnc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def setUpClass(cls):
5454
DNNClassifierModelConfig(
5555
directory=cls.model_dir.name,
5656
steps=1000,
57-
epochs=30,
57+
epochs=40,
5858
hidden=[10, 20, 10],
5959
predict=DefFeature("string", str, 1),
6060
classifications=["a", "not a"],

model/tensorflow/tests/test_dnnr.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,14 @@ def setUpClass(cls):
5050
DNNRegressionModelConfig(
5151
directory=cls.model_dir.name,
5252
steps=1000,
53-
epochs=30,
53+
epochs=40,
5454
hidden=[10, 20, 10],
5555
predict=DefFeature("TARGET", float, 1),
5656
features=cls.features,
5757
)
5858
)
5959
# Generating data f(x1,x2) = 2*x1 + 3*x2
60-
_n_data = 1000
60+
_n_data = 2000
6161
_temp_data = np.random.rand(2, _n_data)
6262
cls.repos = [
6363
Repo(

0 commit comments

Comments
 (0)