Skip to content
This repository was archived by the owner on Aug 25, 2024. It is now read-only.

Commit e5e41b2

Browse files
committed
model: scratch: Use simplified model API
Signed-off-by: John Andersen <[email protected]>
1 parent 60f3b05 commit e5e41b2

File tree

5 files changed

+245
-229
lines changed

5 files changed

+245
-229
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2222
- Test scikit LR documentation examples in CI
2323
- Create a fresh archive of the git repo for release instead of cleaning
2424
existing repo with `git clean` for development service release command.
25+
- Simplified SLR tests for scratch model
2526

2627
## [0.3.4] - 2020-02-28
2728
### Added

docs/plugins/dffml_model.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -548,9 +548,9 @@ hash of their feature names.
548548

549549
- Features to train on
550550

551-
- directory: String
551+
- directory: Path
552552

553-
- default: /home/user/.cache/dffml/scratch
553+
- default: ~/.cache/dffml/scratch
554554
- Directory where state should be saved
555555

556556
dffml_model_scikit
Lines changed: 112 additions & 150 deletions
Original file line numberDiff line numberDiff line change
@@ -1,93 +1,140 @@
1-
# SPDX-License-Identifier: MIT
2-
# Copyright (c) 2019 Intel Corporation
3-
"""
4-
Description of what this model does
5-
"""
6-
import os
7-
import json
8-
import hashlib
1+
import pathlib
92
from typing import AsyncIterator, Tuple, Any
103

114
import numpy as np
125

13-
from dffml.record import Record
14-
from dffml.base import config, field
15-
from dffml.source.source import Sources
16-
from dffml.model.accuracy import Accuracy
17-
from dffml.model.model import ModelContext, Model, ModelNotTrained
18-
from dffml.util.entrypoint import entrypoint
19-
from dffml.feature.feature import Feature, Features
6+
from dffml import (
7+
config,
8+
field,
9+
entrypoint,
10+
SimpleModel,
11+
ModelNotTrained,
12+
Accuracy,
13+
Feature,
14+
Features,
15+
Sources,
16+
Record,
17+
)
2018

2119

2220
@config
2321
class SLRConfig:
2422
predict: Feature = field("Label or the value to be predicted")
2523
features: Features = field("Features to train on")
26-
directory: str = field(
24+
directory: pathlib.Path = field(
2725
"Directory where state should be saved",
28-
default=os.path.join(
29-
os.path.expanduser("~"), ".cache", "dffml", "scratch"
30-
),
26+
default=pathlib.Path("~", ".cache", "dffml", "scratch"),
3127
)
3228

3329

34-
class SLRContext(ModelContext):
35-
def __init__(self, parent):
36-
super().__init__(parent)
30+
@entrypoint("scratchslr")
31+
class SLR(SimpleModel):
32+
r"""
33+
Simple Linear Regression Model for 2 variables implemented from scratch.
34+
Models are saved under the ``directory`` in subdirectories named after the
35+
hash of their feature names.
36+
37+
.. code-block:: console
38+
39+
$ cat > dataset.csv << EOF
40+
Years,Salary
41+
1,40
42+
2,50
43+
3,60
44+
4,70
45+
5,80
46+
EOF
47+
$ dffml train \
48+
-model scratchslr \
49+
-model-features Years:int:1 \
50+
-model-predict Salary:float:1 \
51+
-sources f=csv \
52+
-source-filename dataset.csv \
53+
-log debug
54+
$ dffml accuracy \
55+
-model scratchslr \
56+
-model-features Years:int:1 \
57+
-model-predict Salary:float:1 \
58+
-sources f=csv \
59+
-source-filename dataset.csv \
60+
-log debug
61+
1.0
62+
$ echo -e 'Years,Salary\n6,0\n' | \
63+
dffml predict all \
64+
-model scratchslr \
65+
-model-features Years:int:1 \
66+
-model-predict Salary:float:1 \
67+
-sources f=csv \
68+
-source-filename /dev/stdin \
69+
-log debug
70+
[
71+
{
72+
"extra": {},
73+
"features": {
74+
"Salary": 0,
75+
"Years": 6
76+
},
77+
"last_updated": "2019-07-19T09:46:45Z",
78+
"prediction": {
79+
"Salary": {
80+
"confidence": 1.0,
81+
"value": 90.0
82+
}
83+
},
84+
"key": "0"
85+
}
86+
]
87+
88+
"""
89+
90+
# The configuration class needs to be set as the CONFIG property
91+
CONFIG = SLRConfig
92+
# Simple Linear Regression only supports training on a single feature
93+
NUM_SUPPORTED_FEATURES = 1
94+
# We only support single values (length 1), not matrices / arrays
95+
SUPPORTED_LENGTHS = [1]
96+
97+
def __init__(self, config):
98+
super().__init__(config)
3799
self.xData = np.array([])
38100
self.yData = np.array([])
39-
self.features = self.applicable_features(self.parent.config.features)
40-
self._features_hash_ = hashlib.sha384(
41-
("".join(sorted(self.features))).encode()
42-
).hexdigest()
43101

44102
@property
45103
def regression_line(self):
46-
return self.parent.saved.get(self._features_hash_, None)
104+
"""
105+
Load regression_line from disk; if it hasn't been set yet, return None.
106+
"""
107+
return self.storage.get("regression_line", None)
47108

48109
@regression_line.setter
49110
def regression_line(self, rline):
50-
self.parent.saved[self._features_hash_] = rline
51-
52-
def applicable_features(self, features):
53-
usable = []
54-
if len(features) != 1:
55-
raise ValueError(
56-
"Simple Linear Regression doesn't support features other than 1"
57-
)
58-
for feature in features:
59-
if feature.dtype() != int and feature.dtype() != float:
60-
raise ValueError(
61-
"Simple Linear Regression only supports int or float feature"
62-
)
63-
if feature.length() != 1:
64-
raise ValueError(
65-
"Simple LR only supports single values (non-matrix / array)"
66-
)
67-
usable.append(feature.NAME)
68-
return sorted(usable)
69-
70-
async def predict_input(self, x):
111+
"""
112+
Set regression_line in self.storage so it will be saved to disk
113+
"""
114+
self.storage["regression_line"] = rline
115+
116+
def predict_input(self, x):
117+
"""
118+
Use the regression line to make a prediction by returning ``m * x + b``.
119+
"""
71120
prediction = self.regression_line[0] * x + self.regression_line[1]
72121
self.logger.debug(
73122
"Predicted Value of {} {}:".format(
74-
self.parent.config.predict.NAME, prediction
123+
self.config.predict.NAME, prediction
75124
)
76125
)
77126
return prediction
78127

79-
async def squared_error(self, ys, yline):
128+
def squared_error(self, ys, yline):
80129
return sum((ys - yline) ** 2)
81130

82-
async def coeff_of_deter(self, ys, regression_line):
131+
def coeff_of_deter(self, ys, regression_line):
83132
y_mean_line = [np.mean(ys) for y in ys]
84-
squared_error_mean = await self.squared_error(ys, y_mean_line)
85-
squared_error_regression = await self.squared_error(
86-
ys, regression_line
87-
)
133+
squared_error_mean = self.squared_error(ys, y_mean_line)
134+
squared_error_regression = self.squared_error(ys, regression_line)
88135
return 1 - (squared_error_regression / squared_error_mean)
89136

90-
async def best_fit_line(self):
137+
def best_fit_line(self):
91138
self.logger.debug(
92139
"Number of input records: {}".format(len(self.xData))
93140
)
@@ -100,23 +147,24 @@ async def best_fit_line(self):
100147
)
101148
b = mean_y - (m * mean_x)
102149
regression_line = [m * x + b for x in x]
103-
accuracy = await self.coeff_of_deter(y, regression_line)
150+
accuracy = self.coeff_of_deter(y, regression_line)
104151
return (m, b, accuracy)
105152

106153
async def train(self, sources: Sources):
107154
async for record in sources.with_features(
108-
self.features + [self.parent.config.predict.NAME]
155+
self.features + [self.config.predict.NAME]
109156
):
110157
feature_data = record.features(
111-
self.features + [self.parent.config.predict.NAME]
158+
self.features + [self.config.predict.NAME]
112159
)
113160
self.xData = np.append(self.xData, feature_data[self.features[0]])
114161
self.yData = np.append(
115-
self.yData, feature_data[self.parent.config.predict.NAME]
162+
self.yData, feature_data[self.config.predict.NAME]
116163
)
117-
self.regression_line = await self.best_fit_line()
164+
self.regression_line = self.best_fit_line()
118165

119166
async def accuracy(self, sources: Sources) -> Accuracy:
167+
# Ensure the model has been trained before we try to assess its accuracy
120168
if self.regression_line is None:
121169
raise ModelNotTrained("Train model before assessing for accuracy.")
122170
accuracy_value = self.regression_line[2]
@@ -125,101 +173,15 @@ async def accuracy(self, sources: Sources) -> Accuracy:
125173
async def predict(
126174
self, records: AsyncIterator[Record]
127175
) -> AsyncIterator[Tuple[Record, Any, float]]:
176+
# Ensure the model has been trained before we try to make a prediction
128177
if self.regression_line is None:
129178
raise ModelNotTrained("Train model before prediction.")
130-
target = self.parent.config.predict.NAME
179+
target = self.config.predict.NAME
131180
async for record in records:
132181
feature_data = record.features(self.features)
133182
record.predicted(
134183
target,
135-
await self.predict_input(feature_data[self.features[0]]),
184+
self.predict_input(feature_data[self.features[0]]),
136185
self.regression_line[2],
137186
)
138187
yield record
139-
140-
141-
@entrypoint("slr")
142-
class SLR(Model):
143-
"""
144-
Simple Linear Regression Model for 2 variables implemented from scratch.
145-
Models are saved under the ``directory`` in subdirectories named after the
146-
hash of their feature names.
147-
148-
.. code-block:: console
149-
150-
$ cat > dataset.csv << EOF
151-
Years,Salary
152-
1,40
153-
2,50
154-
3,60
155-
4,70
156-
5,80
157-
EOF
158-
$ dffml train \\
159-
-model scratchslr \\
160-
-model-features Years:int:1 \\
161-
-model-predict Salary:float:1 \\
162-
-sources f=csv \\
163-
-source-filename dataset.csv \\
164-
-log debug
165-
$ dffml accuracy \\
166-
-model scratchslr \\
167-
-model-features Years:int:1 \\
168-
-model-predict Salary:float:1 \\
169-
-sources f=csv \\
170-
-source-filename dataset.csv \\
171-
-log debug
172-
1.0
173-
$ echo -e 'Years,Salary\\n6,0\\n' | \\
174-
dffml predict all \\
175-
-model scratchslr \\
176-
-model-features Years:int:1 \\
177-
-model-predict Salary:float:1 \\
178-
-sources f=csv \\
179-
-source-filename /dev/stdin \\
180-
-log debug
181-
[
182-
{
183-
"extra": {},
184-
"features": {
185-
"Salary": 0,
186-
"Years": 6
187-
},
188-
"last_updated": "2019-07-19T09:46:45Z",
189-
"prediction": {
190-
"Salary": {
191-
"confidence": 1.0,
192-
"value": 90.0
193-
}
194-
},
195-
"key": "0"
196-
}
197-
]
198-
199-
"""
200-
201-
CONTEXT = SLRContext
202-
CONFIG = SLRConfig
203-
204-
def __init__(self, config: SLRConfig) -> None:
205-
super().__init__(config)
206-
self.saved = {}
207-
208-
def _filename(self):
209-
return os.path.join(
210-
self.config.directory,
211-
hashlib.sha384(self.config.predict.NAME.encode()).hexdigest()
212-
+ ".json",
213-
)
214-
215-
async def __aenter__(self) -> SLRContext:
216-
filename = self._filename()
217-
if os.path.isfile(filename):
218-
with open(filename, "r") as read:
219-
self.saved = json.load(read)
220-
return self
221-
222-
async def __aexit__(self, exc_type, exc_value, traceback):
223-
filename = self._filename()
224-
with open(filename, "w") as write:
225-
json.dump(self.saved, write)

0 commit comments

Comments
 (0)