|
10 | 10 | from typing import List, Dict, Any, AsyncIterator, Tuple, Optional, Type |
11 | 11 |
|
12 | 12 | import numpy as np |
13 | | -import tensorflow |
| 13 | + |
| 14 | +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" |
| 15 | +import tensorflow as tf |
14 | 16 |
|
15 | 17 | from dffml.repo import Repo |
16 | | -from dffml.feature import Feature, Features |
17 | | -from dffml.source.source import Sources |
18 | | -from dffml.model.model import ModelConfig, ModelContext, Model, ModelNotTrained |
19 | | -from dffml.accuracy import Accuracy |
20 | | -from dffml.util.entrypoint import entrypoint |
21 | 18 | from dffml.base import BaseConfig |
22 | 19 | from dffml.util.cli.arg import Arg |
23 | | -from dffml.feature.feature import Feature, Features |
24 | | -from dffml.util.cli.parser import list_action |
| 20 | +from dffml.accuracy import Accuracy |
25 | 21 | from dffml.base import config, field |
| 22 | +from dffml.source.source import Sources |
| 23 | +from dffml.feature import Feature, Features |
| 24 | +from dffml.util.entrypoint import entrypoint |
| 25 | +from dffml.util.cli.parser import list_action |
| 26 | +from dffml.feature.feature import Feature, Features |
| 27 | +from dffml.model.model import ModelConfig, ModelContext, Model, ModelNotTrained |
26 | 28 |
|
27 | 29 |
|
28 | 30 | class TensorflowModelContext(ModelContext): |
@@ -65,7 +67,7 @@ def _feature_feature_column(self, feature: Feature): |
65 | 67 | or dtype is float |
66 | 68 | or issubclass(dtype, float) |
67 | 69 | ): |
68 | | - return tensorflow.feature_column.numeric_column( |
| 70 | + return tf.feature_column.numeric_column( |
69 | 71 | feature.NAME, shape=feature.length() |
70 | 72 | ) |
71 | 73 | self.logger.warning( |
@@ -112,7 +114,7 @@ async def predict_input_fn(self, repos: AsyncIterator[Repo], **kwargs): |
112 | 114 | self.logger.info("------ Repo Data ------") |
113 | 115 | self.logger.info("x_cols: %d", len(list(x_cols.values())[0])) |
114 | 116 | self.logger.info("-----------------------") |
115 | | - input_fn = tensorflow.estimator.inputs.numpy_input_fn( |
| 117 | + input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn( |
116 | 118 | x_cols, shuffle=False, num_epochs=1, **kwargs |
117 | 119 | ) |
118 | 120 | return input_fn, ret_repos |
@@ -206,7 +208,7 @@ def model(self): |
206 | 208 | len(self.classifications), |
207 | 209 | self.classifications, |
208 | 210 | ) |
209 | | - self._model = tensorflow.estimator.DNNClassifier( |
| 211 | + self._model = tf.compat.v1.estimator.DNNClassifier( |
210 | 212 | feature_columns=list(self.feature_columns.values()), |
211 | 213 | hidden_units=self.parent.config.hidden, |
212 | 214 | n_classes=len(self.parent.config.classifications), |
@@ -247,7 +249,7 @@ async def training_input_fn( |
247 | 249 | self.logger.info("x_cols: %d", len(list(x_cols.values())[0])) |
248 | 250 | self.logger.info("y_cols: %d", len(y_cols)) |
249 | 251 | self.logger.info("-----------------------") |
250 | | - input_fn = tensorflow.estimator.inputs.numpy_input_fn( |
| 252 | + input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn( |
251 | 253 | x_cols, |
252 | 254 | y_cols, |
253 | 255 | batch_size=self.parent.config.batchsize, |
@@ -287,7 +289,7 @@ async def accuracy_input_fn( |
287 | 289 | self.logger.info("x_cols: %d", len(list(x_cols.values())[0])) |
288 | 290 | self.logger.info("y_cols: %d", len(y_cols)) |
289 | 291 | self.logger.info("-----------------------") |
290 | | - input_fn = tensorflow.estimator.inputs.numpy_input_fn( |
| 292 | + input_fn = tf.compat.v1.estimator.inputs.numpy_input_fn( |
291 | 293 | x_cols, |
292 | 294 | y_cols, |
293 | 295 | batch_size=self.parent.config.batchsize, |
|
0 commit comments