Skip to content
This repository was archived by the owner on Aug 25, 2024. It is now read-only.

Commit e1cd112

Browse files
authored
model: Move scratch slr into main package
- Moved SLR into the main dffml package - Removed scratch:slr Fixes: #500 Fixes: #499
1 parent e51f202 commit e1cd112

File tree

8 files changed

+130
-193
lines changed

8 files changed

+130
-193
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2626
- Treat `"~"` as the home directory rather than a literal
2727
- Windows support by selecting `asyncio.ProactorEventLoop` and not using
2828
`asyncio.FastChildWatcher`.
29+
- Moved SLR into the main dffml package and removed `scratch:slr`.
2930

3031
## [0.3.5] - 2020-03-10
3132
### Added

dffml/model/slr.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
import pathlib
2+
import statistics
3+
from typing import AsyncIterator, Tuple, Any, Type, List
4+
5+
from dffml import (
6+
config,
7+
field,
8+
entrypoint,
9+
SimpleModel,
10+
ModelNotTrained,
11+
Accuracy,
12+
Feature,
13+
Features,
14+
Sources,
15+
Record,
16+
)
17+
18+
19+
def matrix_subtract(one, two):
    """Return the element-wise difference ``one[i] - two[i]`` as a list.

    Elements are paired with ``zip``, so the result is only as long as the
    shorter of the two inputs.
    """
    differences = []
    for minuend, subtrahend in zip(one, two):
        differences.append(minuend - subtrahend)
    return differences
23+
24+
25+
def matrix_multiply(one, two):
    """Return the element-wise product ``one[i] * two[i]`` as a list.

    Elements are paired with ``zip``, so the result is only as long as the
    shorter of the two inputs.
    """
    products = []
    for left, right in zip(one, two):
        products.append(left * right)
    return products
29+
30+
31+
def squared_error(y, line):
    """Return the sum of squared differences between ``y`` and ``line``.

    Values are paired by position; pairing stops at the shorter input.
    """
    total = 0
    for observed, fitted in zip(y, line):
        total += (observed - fitted) ** 2
    return total
33+
34+
35+
def coeff_of_deter(y, regression_line):
    """Return the coefficient of determination (R squared) of a fitted line.

    R squared is ``1 - SS_res / SS_tot`` where ``SS_res`` is the sum of
    squared differences between ``y`` and ``regression_line`` and ``SS_tot``
    is the sum of squared differences between ``y`` and its mean.

    Args:
        y: Observed values.
        regression_line: Predicted values, paired with ``y`` by position.

    Returns:
        The R squared score. When every value in ``y`` is identical,
        ``SS_tot`` is zero and the ratio is undefined; in that case 1.0 is
        returned for a perfect fit and 0.0 otherwise, rather than raising
        ``ZeroDivisionError``.

    Raises:
        statistics.StatisticsError: If ``y`` is empty.
    """
    y_mean = statistics.mean(y)
    # Error of the baseline model that always predicts the mean of y
    squared_error_mean = sum((actual - y_mean) ** 2 for actual in y)
    # Error of the regression line being scored
    squared_error_regression = sum(
        (actual - predicted) ** 2
        for actual, predicted in zip(y, regression_line)
    )
    # Constant y: there is no variance to explain, so avoid dividing by zero
    if squared_error_mean == 0:
        return 1.0 if squared_error_regression == 0 else 0.0
    return 1 - (squared_error_regression / squared_error_mean)
40+
41+
42+
def best_fit_line(x, y):
    """Fit ``y = m * x + b`` to the data and score the fit.

    Args:
        x: Input values.
        y: Observed values, paired with ``x`` by position.

    Returns:
        A tuple ``(m, b, accuracy)``: the slope, the intercept, and the value
        ``coeff_of_deter`` returns for the fitted line against ``y``.

    Raises:
        ZeroDivisionError: If the slope's denominator evaluates to zero,
            which happens when all values in ``x`` are equal.
    """
    x_bar = statistics.mean(x)
    y_bar = statistics.mean(y)
    xy_bar = statistics.mean(matrix_multiply(x, y))
    xx_bar = statistics.mean(matrix_multiply(x, x))
    # Slope of the line
    m = (x_bar * y_bar - xy_bar) / ((x_bar ** 2) - xx_bar)
    # Intercept chosen so the line passes through (x_bar, y_bar)
    b = y_bar - (m * x_bar)
    fitted = [m * x_value + b for x_value in x]
    return (m, b, coeff_of_deter(y, fitted))
52+
53+
54+
@config
class SLRModelConfig:
    # Feature whose value the model is asked to predict
    predict: Feature = field("Label or the value to be predicted")
    # Input features; per the help text, SLR accepts exactly one
    features: Features = field("Features to train on. For SLR only 1 allowed")
    # Location for saved state; "~" is not expanded in this default
    directory: pathlib.Path = field(
        "Directory where state should be saved",
        default=pathlib.Path("~") / ".cache" / "dffml" / "slr",
    )
62+
63+
64+
@entrypoint("slr")
class SLRModel(SimpleModel):
    # Configuration class used by this model, exposed as the CONFIG property
    CONFIG: Type[SLRModelConfig] = SLRModelConfig
    # Simple Linear Regression trains on exactly one feature. Leave
    # NUM_SUPPORTED_FEATURES undefined for models that accept any number.
    NUM_SUPPORTED_FEATURES: int = 1
    # Only single dimensional (non-matrix / array) values are supported.
    # Leave SUPPORTED_LENGTHS undefined for models that accept any dimension.
    SUPPORTED_LENGTHS: List[int] = [1]

    async def train(self, sources: Sources) -> None:
        # Name of the feature we are learning to predict
        label = self.config.predict.NAME
        inputs = []
        targets = []
        # Only records carrying both the input feature and the label are
        # used. The model supports a single feature, so it lives at index 0
        # of self.features.
        async for record in sources.with_features(self.features + [label]):
            inputs.append(record.feature(self.features[0]))
            targets.append(record.feature(label))
        # Report how many records took part in training
        self.logger.debug("Number of input records: %d", len(inputs))
        # Persist the (m, b, accuracy) tuple produced by the fit
        self.storage["regression_line"] = best_fit_line(inputs, targets)

    async def accuracy(self, sources: Sources) -> Accuracy:
        # Fetch the line saved by train(); absent means never trained
        stored = self.storage.get("regression_line", None)
        if stored is None:
            raise ModelNotTrained("Train model before assessing for accuracy.")
        # The stored value is (m, b, accuracy); the score saved at training
        # time is returned and ``sources`` is not consulted.
        return Accuracy(stored[2])

    async def predict(
        self, records: AsyncIterator[Record]
    ) -> AsyncIterator[Tuple[Record, Any, float]]:
        # NOTE(review): the annotation mentions a tuple, but each item
        # yielded below is the record itself.
        # Fetch the line saved by train(); absent means never trained
        stored = self.storage.get("regression_line", None)
        if stored is None:
            raise ModelNotTrained("Train model before prediction.")
        # Unpack slope, intercept and the score saved at training time
        slope, intercept, confidence = stored
        async for record in records:
            # Evaluate the line at this record's input feature value
            value = slope * record.feature(self.features[0]) + intercept
            # Attach the prediction together with the stored score
            record.predicted(self.config.predict.NAME, value, confidence)
            # Hand the updated record back to the caller
            yield record

model/scratch/dffml_model_scratch/slr.py

Lines changed: 0 additions & 187 deletions
This file was deleted.

model/scratch/setup.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@
6767
packages=find_packages(),
6868
entry_points={
6969
"dffml.model": [
70-
"scratchslr = dffml_model_scratch.slr:SLR",
7170
"scratchlgrsag = dffml_model_scratch.logisticregression:LogisticRegression",
7271
]
7372
},

model/scratch/tests/test_slr.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from dffml import train, accuracy, predict, DefFeature, Features, AsyncTestCase
55

6-
from dffml_model_scratch.slr import SLR, SLRConfig
6+
from dffml.model.slr import SLRModel, SLRModelConfig
77

88
TRAIN_DATA = [
99
[12.4, 11.2],
@@ -49,7 +49,7 @@ def setUpClass(cls):
4949
for x, y in TEST_DATA:
5050
cls.test_data.append({"X": x, "Y": y})
5151
# Create an instance of the model
52-
cls.model = SLR(
52+
cls.model = SLRModel(
5353
directory=cls.model_dir.name,
5454
predict=DefFeature("Y", float, 1),
5555
features=Features(DefFeature("X", float, 1)),

model/scratch/tests/test_slr_integration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ async def test_run(self):
2020
# Arguments for the model
2121
model_args = [
2222
"-model",
23-
"scratchslr",
23+
"slr",
2424
"-model-features",
2525
"Years:int:1",
2626
"-model-predict",

setup.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,5 +109,7 @@
109109
"dffml.orchestrator": ["memory = dffml.df.memory:MemoryOrchestrator"],
110110
# Databases
111111
"dffml.db": ["sqlite = dffml.db.sqlite:SqliteDatabase"],
112+
# Models
113+
"dffml.model": ["slr = dffml.model.slr:SLRModel"],
112114
},
113115
)

tests/integration/test_service_dev.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ async def test_run(self):
4444
await CLI.cli(
4545
"train",
4646
"-model",
47-
"scratchslr",
47+
"slr",
4848
"-model-features",
4949
"Years:int:1",
5050
"-model-predict",
@@ -63,7 +63,7 @@ async def test_run(self):
6363
"-features",
6464
json.dumps({"Years": 6}),
6565
"-config-model",
66-
"scratchslr",
66+
"slr",
6767
"-config-model-features",
6868
"Years:int:1",
6969
"-config-model-predict",

0 commit comments

Comments
 (0)