77import hashlib
88import inspect
99import pathlib
10+ from dataclasses import dataclass
1011from typing import List , Dict , Any , AsyncIterator , Type
1112
1213import numpy as np
2324from dffml .model .model import ModelContext , Model , ModelNotTrained
2425
2526
27+ @dataclass
28+ class TensorflowBaseConfig :
29+ predict : Feature = field ("Feature name holding target values" )
30+ features : Features = field ("Features to train on" )
31+ steps : int = field ("Number of steps to train the model" , default = 3000 )
32+ epochs : int = field (
33+ "Number of iterations to pass over all records in a source" , default = 30
34+ )
35+ directory : pathlib .Path = field (
36+ "Directory where state should be saved" ,
37+ default = pathlib .Path ("~" , ".cache" , "dffml" , "tensorflow" ),
38+ )
39+ hidden : List [int ] = field (
40+ "List length is the number of hidden layers in the network. Each entry in the list is the number of nodes in that hidden layer" ,
41+ default_factory = lambda : [12 , 40 , 15 ],
42+ )
43+
44+
2645class TensorflowModelContext (ModelContext ):
2746 """
2847 Tensorflow based model contexts should derive from this model context. As it
@@ -122,6 +141,17 @@ async def train(self, sources: Sources):
122141 input_fn = await self .training_input_fn (sources )
123142 self .model .train (input_fn = input_fn , steps = self .parent .config .steps )
124143
144+ async def get_predictions (self , records : Record ):
145+ if not os .path .isdir (self .model_dir_path ):
146+ raise ModelNotTrained ("Train model before prediction." )
147+ # Create the input function
148+ input_fn , predict = await self .predict_input_fn (records )
149+ # Makes predictions on classifications
150+ predictions = self .model .predict (input_fn = input_fn )
151+ target = self .parent .config .predict .NAME
152+
153+ return predict , predictions , target
154+
125155 @property
126156 @abc .abstractmethod
127157 def model (self ):
@@ -131,29 +161,17 @@ def model(self):
131161
132162
133163@config
134- class DNNClassifierModelConfig :
135- predict : Feature = field ("Feature name holding predict value" )
136- classifications : List [ str ] = field ( "Options for value of classification" )
137- features : Features = field ( "Features to train on" )
164+ class DNNClassifierModelConfig ( TensorflowBaseConfig ) :
165+ classifications : List [ str ] = field (
166+ "Options for value of classification" , default = None
167+ )
138168 clstype : Type = field ("Data type of classifications values" , default = str )
139169 batchsize : int = field (
140170 "Number records to pass through in an epoch" , default = 20
141171 )
142172 shuffle : bool = field (
143173 "Randomise order of records in a batch" , default = True
144174 )
145- steps : int = field ("Number of steps to train the model" , default = 3000 )
146- epochs : int = field (
147- "Number of iterations to pass over all records in a source" , default = 30
148- )
149- directory : pathlib .Path = field (
150- "Directory where state should be saved" ,
151- default = pathlib .Path ("~" , ".cache" , "dffml" , "tensorflow" ),
152- )
153- hidden : List [int ] = field (
154- "List length is the number of hidden layers in the network. Each entry in the list is the number of nodes in that hidden layer" ,
155- default_factory = lambda : [12 , 40 , 15 ],
156- )
157175
158176 def __post_init__ (self ):
159177 self .classifications = list (map (self .clstype , self .classifications ))
@@ -212,11 +230,7 @@ def model(self):
212230 )
213231 return self ._model
214232
215- async def training_input_fn (self , sources : Sources , ** kwargs ):
216- """
217- Uses the numpy input function with data from record features.
218- """
219- self .logger .debug ("Training on features: %r" , self .features )
233+ async def sources_to_array (self , sources : Sources ):
220234 x_cols : Dict [str , Any ] = {feature : [] for feature in self .features }
221235 y_cols = []
222236 for record in [
@@ -239,6 +253,15 @@ async def training_input_fn(self, sources: Sources, **kwargs):
239253 y_cols = np .array (y_cols )
240254 for feature in x_cols :
241255 x_cols [feature ] = np .array (x_cols [feature ])
256+
257+ return x_cols , y_cols
258+
259+ async def training_input_fn (self , sources : Sources , ** kwargs ):
260+ """
261+ Uses the numpy input function with data from record features.
262+ """
263+ self .logger .debug ("Training on features: %r" , self .features )
264+ x_cols , y_cols = await self .sources_to_array (sources )
242265 self .logger .info ("------ Record Data ------" )
243266 self .logger .info ("x_cols: %d" , len (list (x_cols .values ())[0 ]))
244267 self .logger .info ("y_cols: %d" , len (y_cols ))
@@ -257,26 +280,7 @@ async def accuracy_input_fn(self, sources: Sources, **kwargs):
257280 """
258281 Uses the numpy input function with data from record features.
259282 """
260- x_cols : Dict [str , Any ] = {feature : [] for feature in self .features }
261- y_cols = []
262- for record in [
263- record
264- async for record in sources .with_features (
265- self .features + [self .parent .config .predict .NAME ]
266- )
267- if record .feature (self .parent .config .predict .NAME )
268- in self .classifications
269- ]:
270- for feature , results in record .features (self .features ).items ():
271- x_cols [feature ].append (np .array (results ))
272- y_cols .append (
273- self .classifications [
274- record .feature (self .parent .config .predict .NAME )
275- ]
276- )
277- y_cols = np .array (y_cols )
278- for feature in x_cols :
279- x_cols [feature ] = np .array (x_cols [feature ])
283+ x_cols , y_cols = await self .sources_to_array (sources )
280284 self .logger .info ("------ Record Data ------" )
281285 self .logger .info ("x_cols: %d" , len (list (x_cols .values ())[0 ]))
282286 self .logger .info ("y_cols: %d" , len (y_cols ))
@@ -308,13 +312,7 @@ async def predict(
308312 """
309313 Uses trained data to make a prediction about the quality of a record.
310314 """
311- if not os .path .isdir (self .model_dir_path ):
312- raise ModelNotTrained ("Train model before prediction." )
313- # Create the input function
314- input_fn , predict = await self .predict_input_fn (records )
315- # Makes predictions on classifications
316- predictions = self .model .predict (input_fn = input_fn )
317- target = self .parent .config .predict .NAME
315+ predict , predictions , target = await self .get_predictions (records )
318316 for record , pred_dict in zip (predict , predictions ):
319317 class_id = pred_dict ["class_ids" ][0 ]
320318 probability = pred_dict ["probabilities" ][class_id ]
0 commit comments