Skip to content
This repository was archived by the owner on Aug 25, 2024. It is now read-only.

Commit 990ad6a

Browse files
John Andersen (pdxjohnny)
authored and committed
model: Predict only yields repo
Signed-off-by: John Andersen <[email protected]>
1 parent d862967 commit 990ad6a

File tree

9 files changed

+37
-30
lines changed

9 files changed

+37
-30
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
6363
- shouldi example runs bandit now in addition to safety
6464
- The way safety gets called
6565
- Switched documentation to Read The Docs theme
66+
- Models yield only a repo object instead of the value and confidence of the
67+
prediction as well. Models are now responsible for calling the predicted
68+
method on the repo. This will ease the process of making predict feature
69+
specific.
6670
### Fixed
6771
- Docs get version from dffml.version.VERSION.
6872
- FileSource zipfiles are wrapped with TextIOWrapper because CSVSource expects

dffml/model/model.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,7 @@ async def accuracy(self, sources: Sources) -> Accuracy:
5252
raise NotImplementedError()
5353

5454
@abc.abstractmethod
55-
async def predict(
56-
self, repos: AsyncIterator[Repo]
57-
) -> AsyncIterator[Tuple[Repo, Any, float]]:
55+
async def predict(self, repos: AsyncIterator[Repo]) -> AsyncIterator[Repo]:
5856
"""
5957
Uses trained data to make a prediction about the quality of a repo.
6058
"""

dffml/repo.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
Information on the software to evaluate is stored in a Repo instance.
55
"""
66
import os
7+
import warnings
78
from datetime import datetime
89
from typing import Optional, List, Dict, Any, AsyncIterator
910

@@ -21,7 +22,7 @@ class RepoPrediction(dict):
2122

2223
EXPORTED = ["value", "confidence"]
2324

24-
def __init__(self, *, confidence: float = 0.0, value: Any = "") -> None:
25+
def __init__(self, *, confidence: float = 0.0, value: Any = None) -> None:
2526
self["confidence"] = confidence
2627
self["value"] = value
2728

@@ -39,7 +40,7 @@ def dict(self):
3940
return self
4041

4142
def __len__(self):
42-
if self["confidence"] == 0.0 and not self["value"]:
43+
if self["confidence"] == 0.0 and self["value"] is None:
4344
return 0
4445
return 2
4546

@@ -128,6 +129,15 @@ def __init__(
128129
self.extra = extra
129130

130131
def dict(self):
132+
# TODO Remove dict method in favor of export
133+
warnings.warn(
134+
"dict method will be removed in favor of export",
135+
DeprecationWarning,
136+
stacklevel=2,
137+
)
138+
return self.export()
139+
140+
def export(self):
131141
data = self.data.dict()
132142
data["extra"] = self.extra
133143
return data

model/scikit/dffml_model_scikit/scikit_base.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,7 @@ async def accuracy(self, sources: Sources) -> Accuracy:
9191
self.logger.debug("Model Accuracy: {}".format(self.confidence))
9292
return self.confidence
9393

94-
async def predict(
95-
self, repos: AsyncIterator[Repo]
96-
) -> AsyncIterator[Tuple[Repo, Any, float]]:
94+
async def predict(self, repos: AsyncIterator[Repo]) -> AsyncIterator[Repo]:
9795
if self.confidence is None:
9896
raise ValueError("Model Not Trained")
9997
async for repo in repos:
@@ -107,7 +105,8 @@ async def predict(
107105
self.clf.predict(predict),
108106
)
109107
)
110-
yield repo, self.clf.predict(predict)[0], self.confidence
108+
repo.predicted(self.clf.predict(predict)[0], self.confidence)
109+
yield repo
111110

112111

113112
class Scikit(Model):

model/scikit/tests/test_scikit.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,8 @@ async def test_01_accuracy(self):
9292
async def test_02_predict(self):
9393
async with self.sources as sources, self.features as features, self.model as model:
9494
async with sources() as sctx, model(features) as mctx:
95-
async for repo, prediction, confidence in mctx.predict(
96-
sctx.repos()
97-
):
95+
async for repo in mctx.predict(sctx.repos()):
96+
prediction = repo.prediction().value
9897
if self.MODEL_TYPE is "CLASSIFICATION":
9998
self.assertIn(prediction, [2, 4])
10099
elif self.MODEL_TYPE is "REGRESSION":

model/scratch/dffml_model_scratch/slr.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -114,14 +114,14 @@ async def accuracy(self, sources: Sources) -> Accuracy:
114114
accuracy_value = self.regression_line[2]
115115
return Accuracy(accuracy_value)
116116

117-
async def predict(
118-
self, repos: AsyncIterator[Repo]
119-
) -> AsyncIterator[Tuple[Repo, Any, float]]:
117+
async def predict(self, repos: AsyncIterator[Repo]) -> AsyncIterator[Repo]:
120118
async for repo in repos:
121119
feature_data = repo.features(self.features)
122-
yield repo, await self.predict_input(
123-
feature_data[self.features[0]]
124-
), self.regression_line[2]
120+
repo.predicted(
121+
await self.predict_input(feature_data[self.features[0]]),
122+
self.regression_line[2],
123+
)
124+
yield repo
125125

126126

127127
@entry_point("slr")

model/scratch/tests/test_slr.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,10 @@ async def test_context(self):
6868
res = await mctx.accuracy(sctx)
6969
self.assertTrue(0.0 <= res < 1.0)
7070
# Test predict
71-
async for repo, prediction, confidence in mctx.predict(
72-
sctx.repos()
73-
):
71+
async for repo in mctx.predict(sctx.repos()):
7472
correct = FEATURE_DATA[int(repo.src_url)][1]
7573
# Comparison of correct to prediction to make sure prediction is within a reasonable range
74+
prediction = repo.prediction().value
7675
self.assertGreater(prediction, correct - (correct * 0.10))
7776
self.assertLess(prediction, correct + (correct * 0.10))
7877

@@ -90,9 +89,8 @@ async def test_01_accuracy(self):
9089
async def test_02_predict(self):
9190
async with self.sources as sources, self.features as features, self.model as model:
9291
async with sources() as sctx, model(features) as mctx:
93-
async for repo, prediction, confidence in mctx.predict(
94-
sctx.repos()
95-
):
92+
async for repo in mctx.predict(sctx.repos()):
9693
correct = FEATURE_DATA[int(repo.src_url)][1]
94+
prediction = repo.prediction().value
9795
self.assertGreater(prediction, correct - (correct * 0.10))
9896
self.assertLess(prediction, correct + (correct * 0.10))

model/tensorflow/dffml_model_tensorflow/dnnc.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -308,9 +308,7 @@ async def accuracy(self, sources: Sources) -> Accuracy:
308308
accuracy_score = self.model.evaluate(input_fn=input_fn)
309309
return Accuracy(accuracy_score["accuracy"])
310310

311-
async def predict(
312-
self, repos: AsyncIterator[Repo]
313-
) -> AsyncIterator[Tuple[Repo, Any, float]]:
311+
async def predict(self, repos: AsyncIterator[Repo]) -> AsyncIterator[Repo]:
314312
"""
315313
Uses trained data to make a prediction about the quality of a repo.
316314
"""
@@ -323,7 +321,8 @@ async def predict(
323321
for repo, pred_dict in zip(predict, predictions):
324322
class_id = pred_dict["class_ids"][0]
325323
probability = pred_dict["probabilities"][class_id]
326-
yield repo, self.cids[class_id], probability
324+
repo.predicted(self.cids[class_id], probability)
325+
yield repo
327326

328327

329328
@entry_point("tfdnnc")

model/tensorflow/tests/test_dnnc.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,5 +111,5 @@ async def test_02_predict(self):
111111
async with sources() as sctx, model(features) as mctx:
112112
res = [repo async for repo in mctx.predict(sctx.repos())]
113113
self.assertEqual(len(res), 1)
114-
self.assertEqual(res[0][0].src_url, a.src_url)
115-
self.assertTrue(res[0][1])
114+
self.assertEqual(res[0].src_url, a.src_url)
115+
self.assertTrue(res[0].prediction().value)

0 commit comments

Comments
 (0)