66import os
77import json
88import hashlib
9+ import pathlib
910from pathlib import Path
1011from typing import AsyncIterator , Tuple , Any , NamedTuple
1112
2223
2324
2425class ScikitConfig (ModelConfig , NamedTuple ):
25- directory : str
26+ directory : pathlib . Path
2627 predict : Feature
2728 features : Features
2829 tcluster : Feature
@@ -55,14 +56,13 @@ def _feature_predict_hash(self):
5556 "" .join ([params ] + self .features ).encode ()
5657 ).hexdigest ()
5758
@property
def _filepath(self):
    """Return the path of this context's ``.joblib`` file.

    The file name is ``self._features_hash`` with a ``.joblib`` suffix,
    placed inside ``self.parent.config.directory``.

    Returns:
        pathlib.Path: ``<directory>/<features hash>.joblib``.
    """
    # NamedTuple annotations are not enforced at runtime, so the configured
    # directory may still be a plain ``str``; ``str / str`` would raise
    # TypeError.  Path(Path) is a cheap no-op, so always coerce.
    directory = pathlib.Path(self.parent.config.directory)
    return directory / (self._features_hash + ".joblib")
6262
6363 async def __aenter__ (self ):
64- if os . path . isfile ( self . _filename () ):
65- self .clf = joblib .load (self ._filename ( ))
64+ if self . _filepath . is_file ( ):
65+ self .clf = joblib .load (str ( self ._filepath ))
6666 else :
6767 config = self .parent .config ._asdict ()
6868 del config ["directory" ]
@@ -88,10 +88,10 @@ async def train(self, sources: Sources):
8888 ydata = np .array (df [self .parent .config .predict .NAME ])
8989 self .logger .info ("Number of input records: {}" .format (len (xdata )))
9090 self .clf .fit (xdata , ydata )
91- joblib .dump (self .clf , self ._filename ( ))
91+ joblib .dump (self .clf , str ( self ._filepath ))
9292
9393 async def accuracy (self , sources : Sources ) -> Accuracy :
94- if not os . path . isfile ( self . _filename () ):
94+ if not self . _filepath . is_file ( ):
9595 raise ModelNotTrained ("Train model before assessing for accuracy." )
9696 data = []
9797 async for record in sources .with_features (self .features ):
@@ -110,7 +110,7 @@ async def accuracy(self, sources: Sources) -> Accuracy:
110110 async def predict (
111111 self , records : AsyncIterator [Record ]
112112 ) -> AsyncIterator [Tuple [Record , Any , float ]]:
113- if not os . path . isfile ( self . _filename () ):
113+ if not self . _filepath . is_file ( ):
114114 raise ModelNotTrained ("Train model before prediction." )
115115 async for record in records :
116116 feature_data = record .features (self .features )
@@ -132,8 +132,8 @@ async def predict(
132132
133133class ScikitContextUnsprvised (ScikitContext ):
134134 async def __aenter__ (self ):
135- if os . path . isfile ( self . _filename () ):
136- self .clf = joblib .load (self ._filename ( ))
135+ if self . _filepath . is_file ( ):
136+ self .clf = joblib .load (str ( self ._filepath ))
137137 else :
138138 config = self .parent .config ._asdict ()
139139 del config ["directory" ]
@@ -152,10 +152,10 @@ async def train(self, sources: Sources):
152152 xdata = np .array (df )
153153 self .logger .info ("Number of input records: {}" .format (len (xdata )))
154154 self .clf .fit (xdata )
155- joblib .dump (self .clf , self ._filename ( ))
155+ joblib .dump (self .clf , str ( self ._filepath ))
156156
157157 async def accuracy (self , sources : Sources ) -> Accuracy :
158- if not os . path . isfile ( self . _filename () ):
158+ if not self . _filepath . is_file ( ):
159159 raise ModelNotTrained ("Train model before assessing for accuracy." )
160160 data = []
161161 target = []
@@ -205,7 +205,7 @@ async def accuracy(self, sources: Sources) -> Accuracy:
205205 async def predict (
206206 self , records : AsyncIterator [Record ]
207207 ) -> AsyncIterator [Tuple [Record , Any , float ]]:
208- if not os . path . isfile ( self . _filename () ):
208+ if not self . _filepath . is_file ( ):
209209 raise ModelNotTrained ("Train model before prediction." )
210210 estimator_type = self .clf ._estimator_type
211211 if estimator_type is "clusterer" :
@@ -240,28 +240,27 @@ def __init__(self, config) -> None:
240240 super ().__init__ (config )
241241 self .saved = {}
242242
@property
def _filepath(self):
    """Return the path of the JSON file backing ``self.saved``.

    The file name is the SHA-384 hex digest of the predicted feature's
    ``NAME`` with a ``.json`` suffix, placed inside ``self.config.directory``.

    Returns:
        pathlib.Path: ``<directory>/<sha384(predict.NAME)>.json``.
    """
    digest = hashlib.sha384(self.config.predict.NAME.encode()).hexdigest()
    # The configured directory may still be a plain ``str`` (NamedTuple
    # annotations are not enforced); coerce so the ``/`` operator works.
    return pathlib.Path(self.config.directory) / (digest + ".json")
249249
async def __aenter__(self) -> "Scikit":
    """Enter the async context, restoring ``self.saved`` from disk if present.

    When ``self._filepath`` is an existing regular file its text is parsed
    as JSON and stored in ``self.saved``; otherwise ``self.saved`` is left
    untouched.

    Returns:
        The instance itself.
    """
    state_file = self._filepath
    if not state_file.is_file():
        return self
    raw_text = state_file.read_text()
    self.saved = json.loads(raw_text)
    return self
255254
async def __aexit__(self, exc_type, exc_value, traceback):
    """Leave the async context, writing ``self.saved`` to disk as JSON.

    The exception arguments are accepted but not inspected; the write to
    ``self._filepath`` happens whether or not the body raised.
    """
    serialized = json.dumps(self.saved)
    destination = self._filepath
    destination.write_text(serialized)
258257
259258
260259class ScikitUnsprvised (Scikit ):
261- def _filename (self ):
260+ @property
261+ def _filepath (self ):
262262 model_name = self .SCIKIT_MODEL .__name__
263- return os .path .join (
264- self .config .directory ,
263+ return self .config .directory / (
265264 hashlib .sha384 (
266265 (
267266 "" .join (
@@ -272,5 +271,5 @@ def _filename(self):
272271 )
273272 ).encode ()
274273 ).hexdigest ()
275- + ".json" ,
274+ + ".json"
276275 )
0 commit comments