Skip to content
This repository was archived by the owner on Aug 25, 2024. It is now read-only.

Commit fb6cb4e

Browse files
authored
model: scratch: Alternate Logistic Regression implementation
1 parent 261d413 commit fb6cb4e

File tree

6 files changed

+324
-1
lines changed

6 files changed

+324
-1
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99
- Docstrings and doctestable examples to `record.py`.
1010
- Inputs can be validated using operations
1111
- `validate` parameter in `Input` takes `Operation.instance_name`
12+
- Logistic Regression with SAG optimizer
1213
- Test tensorflow DNNEstimator documentation examples in CI
1314
- Add python code for tensorflow DNNEstimator
1415
### Fixed

docs/plugins/dffml_model.rst

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,28 @@ dffml_model_scratch
417417
pip install dffml-model-scratch
418418
419419
420+
scratchlgr
421+
~~~~~~~~~~
422+
423+
*Official*
424+
425+
No description
426+
427+
**Args**
428+
429+
- predict: Feature
430+
431+
- Label or the value to be predicted
432+
433+
- features: List of features
434+
435+
- Features to train on
436+
437+
- directory: Path
438+
439+
- default: ~/.cache/dffml/scratch
440+
- Directory where state should be saved
441+
420442
scratchslr
421443
~~~~~~~~~~
422444

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
import pathlib
2+
from typing import AsyncIterator, Tuple, Any
3+
4+
import numpy as np
5+
6+
from dffml import (
7+
config,
8+
field,
9+
entrypoint,
10+
SimpleModel,
11+
ModelNotTrained,
12+
Accuracy,
13+
Feature,
14+
Features,
15+
Sources,
16+
Record,
17+
)
18+
19+
20+
# Configuration for the scratch logistic regression model: which feature to
# predict, which feature(s) to train on, and where trained state is stored.
@config
class LogisticRegressionConfig:
    # Feature holding the class label the model learns to predict
    predict: Feature = field("Label or the value to be predicted")
    # Input features used for training and prediction
    features: Features = field("Features to train on")
    # Location on disk for the model's saved state
    directory: pathlib.Path = field(
        "Directory where state should be saved",
        default=pathlib.Path("~", ".cache", "dffml", "scratch"),
    )
28+
29+
30+
# Binary (0 / 1) logistic regression on a single scalar feature, written from
# scratch with NumPy and trained by full-batch gradient descent.
#
# The trained state is the tuple ``(w, b, training_accuracy)``, stored under
# the ``separating_line`` key of ``self.storage``.
@entrypoint("scratchlgr")
class LogisticRegression(SimpleModel):

    # The configuration class needs to be set as the CONFIG property
    CONFIG = LogisticRegressionConfig
    # Logistic Regression only supports training on a single feature
    NUM_SUPPORTED_FEATURES = 1
    # We only support single dimensional values, non-matrix / array
    SUPPORTED_LENGTHS = [1]

    def __init__(self, config):
        super().__init__(config)
        # Training samples accumulated by train(); x is the feature value and
        # y is the 0 / 1 label
        self.xData = np.array([])
        self.yData = np.array([])

    @property
    def separating_line(self):
        """
        Load separating_line from disk, if it hasn't been set yet, return None
        """
        return self.storage.get("separating_line", None)

    @separating_line.setter
    def separating_line(self, rline):
        """
        Set separating_line in self.storage so it will be saved to disk
        """
        self.storage["separating_line"] = rline

    def predict_input(self, x):
        """
        Classify a single feature value ``x`` as ``1`` or ``0``.

        The linear score ``w * x + b`` is the log-odds of class ``1``, so the
        predicted probability is above 0.5 exactly when the score is positive.
        """
        line = self.separating_line
        score = line[0] * x + line[1]
        prediction = 1 if score > 0 else 0
        self.logger.debug(
            "Predicted Value of {} {}:".format(
                self.config.predict.NAME, prediction
            )
        )
        return prediction

    def best_fit_line(self, learning_rate=0.01, iterations=1499):
        """
        Fit the weight and bias to ``self.xData`` / ``self.yData``.

        Minimises the mean logistic loss ``log(1 + exp(-s * (w * x + b)))``
        by gradient descent, where ``s`` is ``+1`` for label ``1`` and ``-1``
        for every other label.

        Returns ``(w, b, accuracy)`` where ``accuracy`` is the fraction of
        training samples classified correctly.

        Raises ``ValueError`` when there are no training samples.
        """
        self.logger.debug(
            "Number of input records: {}".format(len(self.xData))
        )
        x = np.asarray(self.xData, dtype=float)
        y = np.asarray(self.yData, dtype=float)
        if x.size == 0:
            raise ValueError(
                "Cannot fit logistic regression without training records"
            )
        # The margin form of the logistic loss needs labels in {-1, +1};
        # using the raw 0 / 1 labels would make every class 0 sample
        # contribute nothing to the gradient.
        signs = np.where(y == 1, 1.0, -1.0)
        w = 0.01
        b = 0.0
        for _ in range(iterations):
            margin = signs * (w * x + b)
            # sigmoid(-margin) == 0.5 * (1 - tanh(margin / 2)); the tanh form
            # cannot overflow the way exp() does for large margins.
            slope = -signs * 0.5 * (1.0 - np.tanh(0.5 * margin))
            grad_w = np.sum(x * slope) / x.size
            grad_b = np.sum(slope) / x.size
            w = w - learning_rate * grad_w
            b = b - learning_rate * grad_b
        # Same decision rule as predict_input()
        predicted = (w * x + b > 0).astype(int)
        accuracy = float(np.mean(predicted == y))
        # Plain floats so the stored state does not hold NumPy scalar types
        return (float(w), float(b), accuracy)

    async def train(self, sources: Sources):
        """
        Add every record that has the feature and the label to the training
        data, then refit and store the model.
        """
        target = self.config.predict.NAME
        wanted = self.features + [target]
        xs = []
        ys = []
        async for record in sources.with_features(wanted):
            feature_data = record.features(wanted)
            xs.append(feature_data[self.features[0]])
            ys.append(feature_data[target])
        # Extend the arrays once instead of reallocating them per record
        self.xData = np.append(self.xData, xs)
        self.yData = np.append(self.yData, ys)
        self.separating_line = self.best_fit_line()

    async def accuracy(self, sources: Sources) -> Accuracy:
        """
        Score the trained model against the labelled records in ``sources``.

        Falls back to the accuracy measured during training when ``sources``
        yields no record that has both the feature and the label.
        """
        # Ensure the model has been trained before we try to make a prediction
        line = self.separating_line
        if line is None:
            raise ModelNotTrained("Train model before assessing for accuracy.")
        target = self.config.predict.NAME
        wanted = self.features + [target]
        total = 0
        correct = 0
        async for record in sources.with_features(wanted):
            feature_data = record.features(wanted)
            total += 1
            predicted = self.predict_input(feature_data[self.features[0]])
            if predicted == feature_data[target]:
                correct += 1
        if total == 0:
            return Accuracy(line[2])
        return Accuracy(correct / total)

    async def predict(
        self, records: AsyncIterator[Record]
    ) -> AsyncIterator[Tuple[Record, Any, float]]:
        """
        Attach a 0 / 1 prediction to each record and yield it. The confidence
        reported with every prediction is the accuracy measured in training.
        """
        # Ensure the model has been trained before we try to make a prediction
        line = self.separating_line
        if line is None:
            raise ModelNotTrained("Train model before prediction.")
        target = self.config.predict.NAME
        feature_name = self.features[0]
        confidence = line[2]
        async for record in records:
            feature_data = record.features(self.features)
            record.predicted(
                target,
                self.predict_input(feature_data[feature_name]),
                confidence,
            )
            yield record

model/scratch/setup.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,5 +65,10 @@
6565
],
6666
install_requires=INSTALL_REQUIRES,
6767
packages=find_packages(),
68-
entry_points={"dffml.model": ["scratchslr = dffml_model_scratch.slr:SLR"]},
68+
entry_points={
69+
"dffml.model": [
70+
"scratchslr = dffml_model_scratch.slr:SLR",
71+
"scratchlgr = dffml_model_scratch.logisticregression:LogisticRegression",
72+
]
73+
},
6974
)

model/scratch/tests/test_lgr.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import tempfile
2+
import unittest
3+
4+
from dffml import train, accuracy, predict, DefFeature, Features, AsyncTestCase
5+
6+
from dffml_model_scratch.logisticregression import (
7+
LogisticRegressionConfig,
8+
LogisticRegression,
9+
)
10+
11+
# Each row is [X, Y]: a feature value and its 0 / 1 label. The two classes
# overlap across X (for example 0.90 -> 0 and 0.92 -> 1).
TRAIN_DATA = [
    [0.90, 0],
    [0.22, 0],
    [0.34, 0],
    [0.09, 0],
    [0.76, 0],
    [0.29, 0],
    [0.98, 0],
    [0.47, 0],
    [0.51, 1],
    [0.60, 1],
    [0.97, 1],
    [0.82, 1],
    [0.24, 1],
    [0.19, 1],
    [0.79, 1],
    [0.92, 1],
]

# Rows in the same [X, Y] layout, used for accuracy and prediction
TEST_DATA = [
    [0.28, 1],
    [0.94, 0],
    [0.64, 1],
    [0.37, 1],
    [0.65, 0],
    [0.09, 1],
    [0.22, 0],
]
39+
40+
41+
class TestLogisticRegression(AsyncTestCase):
    """
    Train, score and predict with the scratch logistic regression model.

    The tests share one model instance and are numbered so that training
    runs before accuracy and prediction.
    """

    @classmethod
    def setUpClass(cls):
        # Create a temporary directory to store the trained model
        cls.model_dir = tempfile.TemporaryDirectory()
        # Create the training data
        cls.train_data = [{"X": x, "Y": y} for x, y in TRAIN_DATA]
        # Create the test data
        cls.test_data = [{"X": x, "Y": y} for x, y in TEST_DATA]
        # Create an instance of the model
        cls.model = LogisticRegression(
            directory=cls.model_dir.name,
            predict=DefFeature("Y", float, 1),
            features=Features(DefFeature("X", float, 1)),
        )

    @classmethod
    def tearDownClass(cls):
        # Remove the temporary directory where the trained model was stored
        cls.model_dir.cleanup()

    async def test_00_train(self):
        # Train the model on the training data
        await train(self.model, *self.train_data)

    async def test_01_accuracy(self):
        # Use the test data to assess the model's accuracy
        res = await accuracy(self.model, *self.test_data)
        # The classes in the fixture data overlap, so only require that a
        # valid fraction is reported
        self.assertTrue(0.0 <= res <= 1.0)

    async def test_02_predict(self):
        # Get the prediction for each piece of test data
        async for i, features, prediction in predict(
            self.model, *self.test_data
        ):
            # Grab the predicted value
            prediction = prediction["Y"]["value"]
            # The model is a binary classifier, it must answer 0 or 1
            self.assertIn(prediction, (0, 1))
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import csv
2+
import json
3+
import pathlib
4+
import contextlib
5+
6+
from dffml.cli.cli import CLI
7+
from dffml.util.asynctestcase import IntegrationCLITestCase
8+
9+
10+
class TestLogisticRegression(IntegrationCLITestCase):
    """
    Drive the ``scratchlgr`` model through the train, accuracy and predict
    CLI commands using a generated CSV file.
    """

    async def test_run(self):
        # Make a temporary directory to store the model
        directory = self.mktempdir()
        # Create the csv data: f1 is 0.0 .. 0.9 and ans is 1 when f1 > 0.5
        data_filename = self.mktempfile() + ".csv"
        # newline="" lets the csv module control line endings itself
        with open(pathlib.Path(data_filename), "w", newline="") as data_file:
            writer = csv.writer(data_file, delimiter=",")
            writer.writerow(["f1", "ans"])
            writer.writerows(
                [[i / 10, int(i / 10 > 0.5)] for i in range(0, 10)]
            )
        # Arguments for the model
        model_args = [
            "-model",
            "scratchlgr",
            "-model-features",
            # The f1 column holds fractional values
            "f1:float:1",
            "-model-predict",
            "ans:int:1",
            "-model-directory",
            directory,
        ]
        # Train the model
        await CLI.cli(
            "train",
            *model_args,
            "-sources",
            "training_data=csv",
            "-source-filename",
            data_filename,
        )
        # Assess accuracy
        await CLI.cli(
            "accuracy",
            *model_args,
            "-sources",
            "test_data=csv",
            "-source-filename",
            data_filename,
        )
        with contextlib.redirect_stdout(self.stdout):
            # Make prediction
            await CLI._main(
                "predict",
                "all",
                *model_args,
                "-sources",
                "predict_data=csv",
                "-source-filename",
                data_filename,
            )
        results = json.loads(self.stdout.getvalue())
        self.assertTrue(isinstance(results, list))
        self.assertEqual(len(results), 10)
        for result in results:
            self.assertIn("prediction", result)
            result = result["prediction"]
            self.assertIn("ans", result)
            result = result["ans"]
            self.assertIn("value", result)
            result = result["value"]
            # The model is a binary classifier, it must answer 0 or 1
            self.assertIn(result, [0, 1])

0 commit comments

Comments
 (0)