Skip to content
This repository was archived by the owner on Aug 25, 2024. It is now read-only.

Commit d88b3fe

Browse files
authored
model: tensorflow: Refactor
* Add TensorflowBaseConfig
* Fix initialisation order
* Add a method to convert sources to numpy array
* Move common code in predict to TensorflowModelContext

Fixes: #383
Signed-off-by: John Andersen <[email protected]>
1 parent 7f877c3 commit d88b3fe

File tree

3 files changed

+73
-100
lines changed

3 files changed

+73
-100
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3131
- Windows support by selecting `asyncio.ProactorEventLoop` and not using
3232
`asyncio.FastChildWatcher`.
3333
- Moved SLR into the main dffml package and removed `scratch:slr`.
34+
### Changed
35+
- Refactor `model/tensorflow`
3436

3537
## [0.3.5] - 2020-03-10
3638
### Added

model/tensorflow/dffml_model_tensorflow/dnnc.py

Lines changed: 46 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import hashlib
88
import inspect
99
import pathlib
10+
from dataclasses import dataclass
1011
from typing import List, Dict, Any, AsyncIterator, Type
1112

1213
import numpy as np
@@ -23,6 +24,24 @@
2324
from dffml.model.model import ModelContext, Model, ModelNotTrained
2425

2526

27+
@dataclass
28+
class TensorflowBaseConfig:
29+
predict: Feature = field("Feature name holding target values")
30+
features: Features = field("Features to train on")
31+
steps: int = field("Number of steps to train the model", default=3000)
32+
epochs: int = field(
33+
"Number of iterations to pass over all records in a source", default=30
34+
)
35+
directory: pathlib.Path = field(
36+
"Directory where state should be saved",
37+
default=pathlib.Path("~", ".cache", "dffml", "tensorflow"),
38+
)
39+
hidden: List[int] = field(
40+
"List length is the number of hidden layers in the network. Each entry in the list is the number of nodes in that hidden layer",
41+
default_factory=lambda: [12, 40, 15],
42+
)
43+
44+
2645
class TensorflowModelContext(ModelContext):
2746
"""
2847
Tensorflow based model contexts should derive from this model context. As it
@@ -122,6 +141,17 @@ async def train(self, sources: Sources):
122141
input_fn = await self.training_input_fn(sources)
123142
self.model.train(input_fn=input_fn, steps=self.parent.config.steps)
124143

144+
async def get_predictions(self, records: Record):
145+
if not os.path.isdir(self.model_dir_path):
146+
raise ModelNotTrained("Train model before prediction.")
147+
# Create the input function
148+
input_fn, predict = await self.predict_input_fn(records)
149+
# Makes predictions on classifications
150+
predictions = self.model.predict(input_fn=input_fn)
151+
target = self.parent.config.predict.NAME
152+
153+
return predict, predictions, target
154+
125155
@property
126156
@abc.abstractmethod
127157
def model(self):
@@ -131,29 +161,17 @@ def model(self):
131161

132162

133163
@config
134-
class DNNClassifierModelConfig:
135-
predict: Feature = field("Feature name holding predict value")
136-
classifications: List[str] = field("Options for value of classification")
137-
features: Features = field("Features to train on")
164+
class DNNClassifierModelConfig(TensorflowBaseConfig):
165+
classifications: List[str] = field(
166+
"Options for value of classification", default=None
167+
)
138168
clstype: Type = field("Data type of classifications values", default=str)
139169
batchsize: int = field(
140170
"Number records to pass through in an epoch", default=20
141171
)
142172
shuffle: bool = field(
143173
"Randomise order of records in a batch", default=True
144174
)
145-
steps: int = field("Number of steps to train the model", default=3000)
146-
epochs: int = field(
147-
"Number of iterations to pass over all records in a source", default=30
148-
)
149-
directory: pathlib.Path = field(
150-
"Directory where state should be saved",
151-
default=pathlib.Path("~", ".cache", "dffml", "tensorflow"),
152-
)
153-
hidden: List[int] = field(
154-
"List length is the number of hidden layers in the network. Each entry in the list is the number of nodes in that hidden layer",
155-
default_factory=lambda: [12, 40, 15],
156-
)
157175

158176
def __post_init__(self):
159177
self.classifications = list(map(self.clstype, self.classifications))
@@ -212,11 +230,7 @@ def model(self):
212230
)
213231
return self._model
214232

215-
async def training_input_fn(self, sources: Sources, **kwargs):
216-
"""
217-
Uses the numpy input function with data from record features.
218-
"""
219-
self.logger.debug("Training on features: %r", self.features)
233+
async def sources_to_array(self, sources: Sources):
220234
x_cols: Dict[str, Any] = {feature: [] for feature in self.features}
221235
y_cols = []
222236
for record in [
@@ -239,6 +253,15 @@ async def training_input_fn(self, sources: Sources, **kwargs):
239253
y_cols = np.array(y_cols)
240254
for feature in x_cols:
241255
x_cols[feature] = np.array(x_cols[feature])
256+
257+
return x_cols, y_cols
258+
259+
async def training_input_fn(self, sources: Sources, **kwargs):
260+
"""
261+
Uses the numpy input function with data from record features.
262+
"""
263+
self.logger.debug("Training on features: %r", self.features)
264+
x_cols, y_cols = await self.sources_to_array(sources)
242265
self.logger.info("------ Record Data ------")
243266
self.logger.info("x_cols: %d", len(list(x_cols.values())[0]))
244267
self.logger.info("y_cols: %d", len(y_cols))
@@ -257,26 +280,7 @@ async def accuracy_input_fn(self, sources: Sources, **kwargs):
257280
"""
258281
Uses the numpy input function with data from record features.
259282
"""
260-
x_cols: Dict[str, Any] = {feature: [] for feature in self.features}
261-
y_cols = []
262-
for record in [
263-
record
264-
async for record in sources.with_features(
265-
self.features + [self.parent.config.predict.NAME]
266-
)
267-
if record.feature(self.parent.config.predict.NAME)
268-
in self.classifications
269-
]:
270-
for feature, results in record.features(self.features).items():
271-
x_cols[feature].append(np.array(results))
272-
y_cols.append(
273-
self.classifications[
274-
record.feature(self.parent.config.predict.NAME)
275-
]
276-
)
277-
y_cols = np.array(y_cols)
278-
for feature in x_cols:
279-
x_cols[feature] = np.array(x_cols[feature])
283+
x_cols, y_cols = await self.sources_to_array(sources)
280284
self.logger.info("------ Record Data ------")
281285
self.logger.info("x_cols: %d", len(list(x_cols.values())[0]))
282286
self.logger.info("y_cols: %d", len(y_cols))
@@ -308,13 +312,7 @@ async def predict(
308312
"""
309313
Uses trained data to make a prediction about the quality of a record.
310314
"""
311-
if not os.path.isdir(self.model_dir_path):
312-
raise ModelNotTrained("Train model before prediction.")
313-
# Create the input function
314-
input_fn, predict = await self.predict_input_fn(records)
315-
# Makes predictions on classifications
316-
predictions = self.model.predict(input_fn=input_fn)
317-
target = self.parent.config.predict.NAME
315+
predict, predictions, target = await self.get_predictions(records)
318316
for record, pred_dict in zip(predict, predictions):
319317
class_id = pred_dict["class_ids"][0]
320318
probability = pred_dict["probabilities"][class_id]

model/tensorflow/dffml_model_tensorflow/dnnr.py

Lines changed: 25 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -3,41 +3,26 @@
33
record.
44
"""
55
import os
6-
import pathlib
7-
from typing import List, Dict, Any, AsyncIterator
6+
from typing import Dict, Any, AsyncIterator
87

98
import numpy as np
109

1110
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
1211
import tensorflow as tf
1312

13+
from dffml.base import config
1414
from dffml.record import Record
1515
from dffml.model.model import Model
1616
from dffml.model.accuracy import Accuracy
1717
from dffml.source.source import Sources
1818
from dffml.util.entrypoint import entrypoint
19-
from dffml.base import config, field
20-
from dffml.feature.feature import Feature, Features
2119

22-
from .dnnc import TensorflowModelContext
20+
from .dnnc import TensorflowModelContext, TensorflowBaseConfig
2321

2422

2523
@config
26-
class DNNRegressionModelConfig:
27-
predict: Feature = field("Feature name holding target values")
28-
features: Features = field("Features to train on")
29-
steps: int = field("Number of steps to train the model", default=3000)
30-
epochs: int = field(
31-
"Number of iterations to pass over all records in a source", default=30
32-
)
33-
directory: pathlib.Path = field(
34-
"Directory where state should be saved",
35-
default=pathlib.Path("~", ".cache", "dffml", "tensorflow"),
36-
)
37-
hidden: List[int] = field(
38-
"List length is the number of hidden layers in the network. Each entry in the list is the number of nodes in that hidden layer",
39-
default_factory=lambda: [12, 40, 15],
40-
)
24+
class DNNRegressionModelConfig(TensorflowBaseConfig):
25+
pass
4126

4227

4328
class DNNRegressionModelContext(TensorflowModelContext):
@@ -71,18 +56,7 @@ def model(self):
7156

7257
return self._model
7358

74-
async def training_input_fn(
75-
self,
76-
sources: Sources,
77-
batch_size=20,
78-
shuffle=False,
79-
epochs=1,
80-
**kwargs,
81-
):
82-
"""
83-
Uses the numpy input function with data from record features.
84-
"""
85-
self.logger.debug("Training on features: %r", self.features)
59+
async def sources_to_array(self, sources: Sources):
8660
x_cols: Dict[str, Any] = {feature: [] for feature in self.features}
8761
y_cols = []
8862

@@ -95,6 +69,22 @@ async def training_input_fn(
9569
y_cols = np.array(y_cols)
9670
for feature in x_cols:
9771
x_cols[feature] = np.array(x_cols[feature])
72+
73+
return x_cols, y_cols
74+
75+
async def training_input_fn(
76+
self,
77+
sources: Sources,
78+
batch_size=20,
79+
shuffle=False,
80+
epochs=1,
81+
**kwargs,
82+
):
83+
"""
84+
Uses the numpy input function with data from record features.
85+
"""
86+
self.logger.debug("Training on features: %r", self.features)
87+
x_cols, y_cols = await self.sources_to_array(sources)
9888
self.logger.info("------ Record Data ------")
9989
self.logger.info("x_cols: %d", len(list(x_cols.values())[0]))
10090
self.logger.info("y_cols: %d", len(y_cols))
@@ -120,17 +110,7 @@ async def evaluate_input_fn(
120110
"""
121111
Uses the numpy input function with data from record features.
122112
"""
123-
x_cols: Dict[str, Any] = {feature: [] for feature in self.features}
124-
y_cols = []
125-
126-
async for record in sources.with_features(self.all_features):
127-
for feature, results in record.features(self.features).items():
128-
x_cols[feature].append(np.array(results))
129-
y_cols.append(record.feature(self.parent.config.predict.NAME))
130-
131-
y_cols = np.array(y_cols)
132-
for feature in x_cols:
133-
x_cols[feature] = np.array(x_cols[feature])
113+
x_cols, y_cols = await self.sources_to_array(sources)
134114
self.logger.info("------ Record Data ------")
135115
self.logger.info("x_cols: %d", len(list(x_cols.values())[0]))
136116
self.logger.info("y_cols: %d", len(y_cols))
@@ -164,15 +144,8 @@ async def predict(
164144
"""
165145
Uses trained data to make a prediction about the quality of a record.
166146
"""
167-
168-
if not os.path.isdir(self.model_dir_path):
169-
raise NotADirectoryError("Model not trained")
170-
# Create the input function
171-
input_fn, predict_record = await self.predict_input_fn(records)
172-
# Makes predictions on
173-
predictions = self.model.predict(input_fn=input_fn)
174-
target = self.parent.config.predict.NAME
175-
for record, pred_dict in zip(predict_record, predictions):
147+
predict, predictions, target = await self.get_predictions(records)
148+
for record, pred_dict in zip(predict, predictions):
176149
# TODO Instead of float("nan") save accuracy value and use that.
177150
record.predicted(
178151
target, float(pred_dict["predictions"]), float("nan")

0 commit comments

Comments
 (0)