Skip to content
This repository was archived by the owner on Aug 25, 2024. It is now read-only.

Commit 3a32d98

Browse files
committed
skel: model: Convert to SimpleModel
Signed-off-by: John Andersen <[email protected]>
1 parent 55bb801 commit 3a32d98

File tree

2 files changed

+168
-111
lines changed

2 files changed

+168
-111
lines changed
Lines changed: 108 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,122 @@
1-
# SPDX-License-Identifier: MIT
2-
# Copyright (c) 2019 Intel Corporation
3-
"""
4-
Description of what this model does
5-
"""
6-
from typing import AsyncIterator, Tuple, Any, List
7-
8-
from dffml.record import Record
9-
from dffml.source.source import Sources
10-
from dffml.feature import Features
11-
from dffml.model.accuracy import Accuracy
12-
from dffml.model.model import ModelContext, Model
13-
from dffml.util.entrypoint import entrypoint
14-
from dffml.base import config
1+
import pathlib
2+
import statistics
3+
from typing import AsyncIterator, Tuple, Any, Type, List
4+
5+
from dffml import (
6+
config,
7+
field,
8+
entrypoint,
9+
SimpleModel,
10+
ModelNotTrained,
11+
Accuracy,
12+
Feature,
13+
Features,
14+
Sources,
15+
Record,
16+
)
17+
18+
19+
def matrix_subtract(one, two):
20+
return [
21+
one_element - two_element for one_element, two_element in zip(one, two)
22+
]
23+
24+
25+
def matrix_multiply(one, two):
26+
return [
27+
one_element * two_element for one_element, two_element in zip(one, two)
28+
]
29+
30+
31+
def squared_error(y, line):
32+
return sum(map(lambda element: element ** 2, matrix_subtract(y, line)))
33+
34+
35+
def coeff_of_deter(y, regression_line):
36+
y_mean_line = [statistics.mean(y)] * len(y)
37+
squared_error_mean = squared_error(y, y_mean_line)
38+
squared_error_regression = squared_error(y, regression_line)
39+
return 1 - (squared_error_regression / squared_error_mean)
40+
41+
42+
def best_fit_line(x, y):
43+
mean_x = statistics.mean(x)
44+
mean_y = statistics.mean(y)
45+
m = (mean_x * mean_y - statistics.mean(matrix_multiply(x, y))) / (
46+
(mean_x ** 2) - statistics.mean(matrix_multiply(x, x))
47+
)
48+
b = mean_y - (m * mean_x)
49+
regression_line = [m * x + b for x in x]
50+
accuracy = coeff_of_deter(y, regression_line)
51+
return (m, b, accuracy)
1552

1653

1754
@config
1855
class MiscModelConfig:
19-
# This model never uses the directory, but chances are if you want to save
20-
# and load data from disk you will need to
21-
directory: str
22-
classifications: List[str]
23-
features: Features
56+
predict: Feature = field("Label or the value to be predicted")
57+
features: Features = field("Features to train on. For SLR only 1 allowed")
58+
directory: pathlib.Path = field(
59+
"Directory where state should be saved",
60+
default=pathlib.Path("~", ".cache", "dffml", "miscmodel"),
61+
)
2462

2563

26-
class MiscModelContext(ModelContext):
27-
"""
28-
Model wraping model_name API
29-
"""
64+
@entrypoint("miscmodel")
65+
class MiscModel(SimpleModel):
66+
# The configuration class needs to be set as the CONFIG property
67+
CONFIG: Type[MiscModelConfig] = MiscModelConfig
68+
# Simple Linear Regression only supports training on a single feature.
69+
# Do not define NUM_SUPPORTED_FEATURES if you support arbitrary numbers of
70+
# features.
71+
NUM_SUPPORTED_FEATURES: int = 1
72+
# We only support single dimensional values, non-matrix / array
73+
# Do not define SUPPORTED_LENGTHS if you support arbitrary dimensions
74+
SUPPORTED_LENGTHS: List[int] = [1]
3075

31-
async def train(self, sources: Sources):
32-
"""
33-
Train using records as the data to learn from.
34-
"""
35-
pass
76+
async def train(self, sources: Sources) -> None:
77+
# X and Y data
78+
x = []
79+
y = []
80+
# Go through all records that have the feature we're training on and the
81+
# feature we want to predict. Since our model only supports 1 feature,
82+
# the self.features list will only have one element at index 0.
83+
async for record in sources.with_features(
84+
self.features + [self.config.predict.NAME]
85+
):
86+
x.append(record.feature(self.features[0]))
87+
y.append(record.feature(self.config.predict.NAME))
88+
# Use self.logger to report how many records are being used for training
89+
self.logger.debug("Number of input records: %d", len(x))
90+
# Save m, b, and accuracy
91+
self.storage["regression_line"] = best_fit_line(x, y)
3692

3793
async def accuracy(self, sources: Sources) -> Accuracy:
38-
"""
39-
Evaluates the accuracy of our model after training using the input records
40-
as test data.
41-
"""
42-
# Lies
43-
return 1.0
94+
# Load saved regression line
95+
regression_line = self.storage.get("regression_line", None)
96+
# Ensure the model has been trained before we try to make a prediction
97+
if regression_line is None:
98+
raise ModelNotTrained("Train model before assessing for accuracy.")
99+
# Accuracy is the last element in regression_line, which is a list of
100+
# three values: m, b, and accuracy.
101+
return Accuracy(regression_line[2])
44102

45103
async def predict(
46104
self, records: AsyncIterator[Record]
47105
) -> AsyncIterator[Tuple[Record, Any, float]]:
48-
"""
49-
Uses trained data to make a prediction about the quality of a record.
50-
"""
106+
# Load saved regression line
107+
regression_line = self.storage.get("regression_line", None)
108+
# Ensure the model has been trained before we try to make a prediction
109+
if regression_line is None:
110+
raise ModelNotTrained("Train model before prediction.")
111+
# Expand the regression_line into named variables
112+
m, b, accuracy = regression_line
113+
# Iterate through each record that needs a prediction
51114
async for record in records:
52-
yield record, self.parent.config.classifications[
53-
record.feature(self.parent.config.features.names()[0])
54-
], 1.0
55-
56-
57-
@entrypoint("misc")
58-
class MiscModel(Model):
59-
60-
CONTEXT = MiscModelContext
115+
# Grab the x data from the record
116+
x = record.feature(self.features[0])
117+
# Calculate y
118+
y = m * x + b
119+
# Set the calculated value with the estimated accuracy
120+
record.predicted(self.config.predict.NAME, y, accuracy)
121+
# Yield the record to the caller
122+
yield record
Lines changed: 60 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -1,88 +1,83 @@
1-
import random
21
import tempfile
3-
from typing import Type
42

5-
from dffml.record import Record, RecordData
6-
from dffml.source.source import Sources
7-
from dffml.source.memory import MemorySource, MemorySourceConfig
8-
from dffml.feature import Data, Feature, Features
9-
from dffml.util.asynctestcase import AsyncTestCase
3+
from dffml import train, accuracy, predict, DefFeature, Features, AsyncTestCase
104

115
from REPLACE_IMPORT_PACKAGE_NAME.misc import MiscModel, MiscModelConfig
126

7+
TRAIN_DATA = [
8+
[12.4, 11.2],
9+
[14.3, 12.5],
10+
[14.5, 12.7],
11+
[14.9, 13.1],
12+
[16.1, 14.1],
13+
[16.9, 14.8],
14+
[16.5, 14.4],
15+
[15.4, 13.4],
16+
[17.0, 14.9],
17+
[17.9, 15.6],
18+
[18.8, 16.4],
19+
[20.3, 17.7],
20+
[22.4, 19.6],
21+
[19.4, 16.9],
22+
[15.5, 14.0],
23+
[16.7, 14.6],
24+
]
1325

14-
class StartsWithA(Feature):
26+
TEST_DATA = [
27+
[17.3, 15.1],
28+
[18.4, 16.1],
29+
[19.2, 16.8],
30+
[17.4, 15.2],
31+
[19.5, 17.0],
32+
[19.7, 17.2],
33+
[21.2, 18.6],
34+
]
1535

16-
NAME: str = "starts_with_a"
1736

18-
def dtype(self) -> Type:
19-
return int
20-
21-
def length(self) -> int:
22-
return 1
23-
24-
async def calc(self, data: Data) -> int:
25-
return 1 if data.key.lower().startswith("a") else 0
26-
27-
28-
class TestMisc(AsyncTestCase):
37+
class TestMiscModel(AsyncTestCase):
2938
@classmethod
3039
def setUpClass(cls):
31-
cls.feature = StartsWithA()
32-
cls.features = Features(cls.feature)
40+
# Create a temporary directory to store the trained model
3341
cls.model_dir = tempfile.TemporaryDirectory()
42+
# Create the training data
43+
cls.train_data = []
44+
for x, y in TRAIN_DATA:
45+
cls.train_data.append({"X": x, "Y": y})
46+
# Create the test data
47+
cls.test_data = []
48+
for x, y in TEST_DATA:
49+
cls.test_data.append({"X": x, "Y": y})
50+
# Create an instance of the model
3451
cls.model = MiscModel(
35-
MiscModelConfig(
36-
directory=cls.model_dir.name,
37-
classifications=["not a", "a"],
38-
features=cls.features,
39-
)
40-
)
41-
cls.records = [
42-
Record(
43-
"a" + str(random.random()),
44-
data={"features": {cls.feature.NAME: 1, "string": "a"}},
45-
)
46-
for _ in range(0, 1000)
47-
]
48-
cls.records += [
49-
Record(
50-
"b" + str(random.random()),
51-
data={"features": {cls.feature.NAME: 0, "string": "not a"}},
52-
)
53-
for _ in range(0, 1000)
54-
]
55-
cls.sources = Sources(
56-
MemorySource(MemorySourceConfig(records=cls.records))
52+
directory=cls.model_dir.name,
53+
predict=DefFeature("Y", float, 1),
54+
features=Features(DefFeature("X", float, 1)),
5755
)
5856

5957
@classmethod
6058
def tearDownClass(cls):
59+
# Remove the temporary directory where the trained model was stored
6160
cls.model_dir.cleanup()
6261

6362
async def test_00_train(self):
64-
async with self.sources as sources, self.model as model:
65-
async with sources() as sctx, model() as mctx:
66-
await mctx.train(sctx)
63+
# Train the model on the training data
64+
await train(self.model, *self.train_data)
6765

6866
async def test_01_accuracy(self):
69-
async with self.sources as sources, self.model as model:
70-
async with sources() as sctx, model() as mctx:
71-
res = await mctx.accuracy(sctx)
72-
self.assertGreater(res, 0.9)
67+
# Use the test data to assess the model's accuracy
68+
res = await accuracy(self.model, *self.test_data)
69+
# Ensure the accuracy is above 80%
70+
self.assertTrue(0.8 <= res < 1.0)
7371

7472
async def test_02_predict(self):
75-
a = Record("a", data={"features": {self.feature.NAME: 1}})
76-
b = Record("not a", data={"features": {self.feature.NAME: 0}})
77-
async with Sources(
78-
MemorySource(MemorySourceConfig(records=[a, b]))
79-
) as sources, self.model as model:
80-
async with sources() as sctx, model() as mctx:
81-
num = 0
82-
async for record, prediction, confidence in mctx.predict(
83-
sctx.records()
84-
):
85-
with self.subTest(record=record):
86-
self.assertEqual(prediction, record.key)
87-
num += 1
88-
self.assertEqual(num, 2)
73+
# Get the prediction for each piece of test data
74+
async for i, features, prediction in predict(
75+
self.model, *self.test_data
76+
):
77+
# Grab the correct value
78+
correct = self.test_data[i]["Y"]
79+
# Grab the predicted value
80+
prediction = prediction["Y"]["value"]
81+
# Check that the percent error is less than 10%
82+
self.assertLess(prediction, correct * 1.1)
83+
self.assertGreater(prediction, correct * (1.0 - 0.1))

0 commit comments

Comments
 (0)