Skip to content
This repository was archived by the owner on Aug 25, 2024. It is now read-only.

Commit e368d13

Browse files
committed
model: Switch directory config parameter to pathlib.Path
Signed-off-by: John Andersen <[email protected]>
1 parent b922a6e commit e368d13

File tree

12 files changed

+63
-62
lines changed

12 files changed

+63
-62
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2727
- Test tensorflow DNNClassifier documentation exaples in CI
2828
- config directories and files associated with ConfigLoaders have been renamed
2929
to configloader.
30+
- Model config directory parameters are now `pathlib.Path` objects
3031

3132
## [0.3.4] - 2020-02-28
3233
### Added

dffml/model/model.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,14 @@ class Model(BaseDataFlowFacilitatorObject):
7979

8080
CONFIG = ModelConfig
8181

82+
def __init__(self, config):
83+
super().__init__(config)
84+
# TODO Just in case its a string. We should make it so that on
85+
# instantiation of an @config we convert properties to their correct
86+
# types.
87+
if isinstance(getattr(self.config, "directory", None), str):
88+
self.config.directory = pathlib.Path(self.config.directory)
89+
8290
def __call__(self) -> ModelContext:
8391
self._make_config_directory()
8492
return self.CONTEXT(self)

docs/plugins/dffml_model.rst

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,9 @@ Example usage of Tensorflow DNNClassifier model using python API
127127
- default: 30
128128
- Number of iterations to pass over all records in a source
129129

130-
- directory: String
130+
- directory: Path
131131

132-
- default: /home/user/.cache/dffml/tensorflow
132+
- default: ~/.cache/dffml/tensorflow
133133
- Directory where state should be saved
134134

135135
- hidden: List of integers
@@ -248,9 +248,9 @@ predict).
248248
- default: 30
249249
- Number of iterations to pass over all records in a source
250250

251-
- directory: String
251+
- directory: Path
252252

253-
- default: /home/user/.cache/dffml/tensorflow
253+
- default: ~/.cache/dffml/tensorflow
254254
- Directory where state should be saved
255255

256256
- hidden: List of integers
@@ -426,9 +426,9 @@ Implemented using Tensorflow hub pretrained models.
426426
- default: 10
427427
- Number of iterations to pass over all records in a source
428428

429-
- directory: String
429+
- directory: Path
430430

431-
- default: /home/user/.cache/dffml/tensorflow_hub
431+
- default: ~/.cache/dffml/tensorflow_hub
432432
- Directory where state should be saved
433433

434434
dffml_model_scratch
@@ -832,8 +832,8 @@ Ensure that `predict` and `accuracy` for these algorithms uses training data.
832832

833833
- Features to train on
834834

835-
- directory: String
835+
- directory: Path
836836

837-
- default: /home/user/.cache/dffml/scikit-{Entrypoint}
837+
- default: ~/.cache/dffml/scikit-{entrypoint}
838838
- Directory where state should be saved
839839

model/scikit/dffml_model_scikit/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,9 +309,9 @@
309309
310310
- Features to train on
311311
312-
- directory: String
312+
- directory: Path
313313
314-
- default: /home/user/.cache/dffml/scikit-{Entrypoint}
314+
- default: ~/.cache/dffml/scikit-{entrypoint}
315315
- Directory where state should be saved
316316
317317
"""

model/scikit/dffml_model_scikit/scikit_base.py

Lines changed: 26 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import os
77
import json
88
import hashlib
9+
import pathlib
910
from pathlib import Path
1011
from typing import AsyncIterator, Tuple, Any, NamedTuple
1112

@@ -22,7 +23,7 @@
2223

2324

2425
class ScikitConfig(ModelConfig, NamedTuple):
25-
directory: str
26+
directory: pathlib.Path
2627
predict: Feature
2728
features: Features
2829
tcluster: Feature
@@ -55,14 +56,13 @@ def _feature_predict_hash(self):
5556
"".join([params] + self.features).encode()
5657
).hexdigest()
5758

58-
def _filename(self):
59-
return os.path.join(
60-
self.parent.config.directory, self._features_hash + ".joblib"
61-
)
59+
@property
60+
def _filepath(self):
61+
return self.parent.config.directory / (self._features_hash + ".joblib")
6262

6363
async def __aenter__(self):
64-
if os.path.isfile(self._filename()):
65-
self.clf = joblib.load(self._filename())
64+
if self._filepath.is_file():
65+
self.clf = joblib.load(str(self._filepath))
6666
else:
6767
config = self.parent.config._asdict()
6868
del config["directory"]
@@ -88,10 +88,10 @@ async def train(self, sources: Sources):
8888
ydata = np.array(df[self.parent.config.predict.NAME])
8989
self.logger.info("Number of input records: {}".format(len(xdata)))
9090
self.clf.fit(xdata, ydata)
91-
joblib.dump(self.clf, self._filename())
91+
joblib.dump(self.clf, str(self._filepath))
9292

9393
async def accuracy(self, sources: Sources) -> Accuracy:
94-
if not os.path.isfile(self._filename()):
94+
if not self._filepath.is_file():
9595
raise ModelNotTrained("Train model before assessing for accuracy.")
9696
data = []
9797
async for record in sources.with_features(self.features):
@@ -110,7 +110,7 @@ async def accuracy(self, sources: Sources) -> Accuracy:
110110
async def predict(
111111
self, records: AsyncIterator[Record]
112112
) -> AsyncIterator[Tuple[Record, Any, float]]:
113-
if not os.path.isfile(self._filename()):
113+
if not self._filepath.is_file():
114114
raise ModelNotTrained("Train model before prediction.")
115115
async for record in records:
116116
feature_data = record.features(self.features)
@@ -132,8 +132,8 @@ async def predict(
132132

133133
class ScikitContextUnsprvised(ScikitContext):
134134
async def __aenter__(self):
135-
if os.path.isfile(self._filename()):
136-
self.clf = joblib.load(self._filename())
135+
if self._filepath.is_file():
136+
self.clf = joblib.load(str(self._filepath))
137137
else:
138138
config = self.parent.config._asdict()
139139
del config["directory"]
@@ -152,10 +152,10 @@ async def train(self, sources: Sources):
152152
xdata = np.array(df)
153153
self.logger.info("Number of input records: {}".format(len(xdata)))
154154
self.clf.fit(xdata)
155-
joblib.dump(self.clf, self._filename())
155+
joblib.dump(self.clf, str(self._filepath))
156156

157157
async def accuracy(self, sources: Sources) -> Accuracy:
158-
if not os.path.isfile(self._filename()):
158+
if not self._filepath.is_file():
159159
raise ModelNotTrained("Train model before assessing for accuracy.")
160160
data = []
161161
target = []
@@ -205,7 +205,7 @@ async def accuracy(self, sources: Sources) -> Accuracy:
205205
async def predict(
206206
self, records: AsyncIterator[Record]
207207
) -> AsyncIterator[Tuple[Record, Any, float]]:
208-
if not os.path.isfile(self._filename()):
208+
if not self._filepath.is_file():
209209
raise ModelNotTrained("Train model before prediction.")
210210
estimator_type = self.clf._estimator_type
211211
if estimator_type is "clusterer":
@@ -240,28 +240,27 @@ def __init__(self, config) -> None:
240240
super().__init__(config)
241241
self.saved = {}
242242

243-
def _filename(self):
244-
return os.path.join(
245-
self.config.directory,
243+
@property
244+
def _filepath(self):
245+
return self.config.directory / (
246246
hashlib.sha384(self.config.predict.NAME.encode()).hexdigest()
247-
+ ".json",
247+
+ ".json"
248248
)
249249

250250
async def __aenter__(self) -> "Scikit":
251-
path = Path(self._filename())
252-
if path.is_file():
253-
self.saved = json.loads(path.read_text())
251+
if self._filepath.is_file():
252+
self.saved = json.loads(self._filepath.read_text())
254253
return self
255254

256255
async def __aexit__(self, exc_type, exc_value, traceback):
257-
Path(self._filename()).write_text(json.dumps(self.saved))
256+
self._filepath.write_text(json.dumps(self.saved))
258257

259258

260259
class ScikitUnsprvised(Scikit):
261-
def _filename(self):
260+
@property
261+
def _filepath(self):
262262
model_name = self.SCIKIT_MODEL.__name__
263-
return os.path.join(
264-
self.config.directory,
263+
return self.config.directory / (
265264
hashlib.sha384(
266265
(
267266
"".join(
@@ -272,5 +271,5 @@ def _filename(self):
272271
)
273272
).encode()
274273
).hexdigest()
275-
+ ".json",
274+
+ ".json"
276275
)

model/scikit/dffml_model_scikit/scikit_models.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""
66
import os
77
import sys
8+
import pathlib
89

910
from sklearn.neural_network import MLPClassifier
1011
from sklearn.neighbors import KNeighborsClassifier
@@ -221,14 +222,11 @@ def applicable_features(self, features):
221222
dffml_config_properties = {
222223
**{
223224
"directory": (
224-
str,
225+
pathlib.Path,
225226
field(
226227
"Directory where state should be saved",
227-
default=os.path.join(
228-
os.path.expanduser("~"),
229-
".cache",
230-
"dffml",
231-
f"scikit-{entry_point_name}",
228+
default=pathlib.Path(
229+
"~", ".cache", "dffml", f"scikit-{entry_point_name}",
232230
),
233231
),
234232
),

model/tensorflow/dffml_model_tensorflow/dnnc.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import abc
77
import hashlib
88
import inspect
9+
import pathlib
910
from typing import List, Dict, Any, AsyncIterator, Type
1011

1112
import numpy as np
@@ -145,11 +146,9 @@ class DNNClassifierModelConfig:
145146
epochs: int = field(
146147
"Number of iterations to pass over all records in a source", default=30
147148
)
148-
directory: str = field(
149+
directory: pathlib.Path = field(
149150
"Directory where state should be saved",
150-
default=os.path.join(
151-
os.path.expanduser("~"), ".cache", "dffml", "tensorflow"
152-
),
151+
default=pathlib.Path("~", ".cache", "dffml", "tensorflow"),
153152
)
154153
hidden: List[int] = field(
155154
"List length is the number of hidden layers in the network. Each entry in the list is the number of nodes in that hidden layer",

model/tensorflow/dffml_model_tensorflow/dnnr.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
record.
44
"""
55
import os
6+
import pathlib
67
from typing import List, Dict, Any, AsyncIterator
78

89
import numpy as np
@@ -29,11 +30,9 @@ class DNNRegressionModelConfig:
2930
epochs: int = field(
3031
"Number of iterations to pass over all records in a source", default=30
3132
)
32-
directory: str = field(
33+
directory: pathlib.Path = field(
3334
"Directory where state should be saved",
34-
default=os.path.join(
35-
os.path.expanduser("~"), ".cache", "dffml", "tensorflow"
36-
),
35+
default=pathlib.Path("~", ".cache", "dffml", "tensorflow"),
3736
)
3837
hidden: List[int] = field(
3938
"List length is the number of hidden layers in the network. Each entry in the list is the number of nodes in that hidden layer",

model/tensorflow/tests/test_dnnc.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import random
3+
import pathlib
34
import tempfile
45
from typing import Type
56

@@ -84,9 +85,7 @@ async def test_config(self):
8485
)
8586
self.assertEqual(
8687
config.directory,
87-
os.path.join(
88-
os.path.expanduser("~"), ".cache", "dffml", "tensorflow"
89-
),
88+
pathlib.Path("~", ".cache", "dffml", "tensorflow"),
9089
)
9190
self.assertEqual(config.steps, 3000)
9291
self.assertEqual(config.epochs, 30)

model/tensorflow/tests/test_dnnr.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import random
3+
import pathlib
34
import tempfile
45
from typing import Type
56

@@ -95,9 +96,7 @@ async def test_config(self):
9596
)
9697
self.assertEqual(
9798
config.directory,
98-
os.path.join(
99-
os.path.expanduser("~"), ".cache", "dffml", "tensorflow"
100-
),
99+
pathlib.Path("~", ".cache", "dffml", "tensorflow"),
101100
)
102101
self.assertEqual(config.steps, 3000)
103102
self.assertEqual(config.epochs, 30)

0 commit comments

Comments
 (0)