Skip to content
This repository was archived by the owner on Aug 25, 2024. It is now read-only.

Commit cfd37f9

Browse files
yashlamba authored and John Andersen committed
model: Added ModelNotTrained Error
Fixes: #125
1 parent 8ef81ce commit cfd37f9

File tree

4 files changed: 25 additions and 12 deletions.

dffml/model/model.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -22,6 +22,10 @@
2222
from ..util.entrypoint import Entrypoint, base_entry_point
2323

2424

25+
class ModelNotTrained(Exception):
26+
pass
27+
28+
2529
class ModelConfig(BaseConfig, NamedTuple):
2630
directory: str
2731

model/scikit/dffml_model_scikit/scikit_base.py

Lines changed: 11 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
"""
66
import os
77
import json
8+
import math
89
import hashlib
910
from pathlib import Path
1011
from typing import AsyncIterator, Tuple, Any, NamedTuple
@@ -16,7 +17,7 @@
1617
from dffml.repo import Repo
1718
from dffml.source.source import Sources
1819
from dffml.accuracy import Accuracy
19-
from dffml.model.model import ModelConfig, ModelContext, Model
20+
from dffml.model.model import ModelConfig, ModelContext, Model, ModelNotTrained
2021

2122

2223
class ScikitConfig(ModelConfig, NamedTuple):
@@ -33,7 +34,7 @@ def __init__(self, parent, features):
3334

3435
@property
3536
def confidence(self):
36-
return self.parent.saved.get(self._features_hash, None)
37+
return self.parent.saved.get(self._features_hash, float("nan"))
3738

3839
@confidence.setter
3940
def confidence(self, confidence):
@@ -60,7 +61,7 @@ async def __aenter__(self):
6061
return self
6162

6263
async def __aexit__(self, exc_type, exc_value, traceback):
63-
joblib.dump(self.clf, self._filename())
64+
pass
6465

6566
async def train(self, sources: Sources):
6667
data = []
@@ -77,6 +78,8 @@ async def train(self, sources: Sources):
7778
joblib.dump(self.clf, self._filename())
7879

7980
async def accuracy(self, sources: Sources) -> Accuracy:
81+
if not os.path.isfile(self._filename()):
82+
raise ModelNotTrained("Train model before assessing for accuracy.")
8083
data = []
8184
async for repo in sources.with_features(self.features):
8285
feature_data = repo.features(
@@ -91,9 +94,11 @@ async def accuracy(self, sources: Sources) -> Accuracy:
9194
self.logger.debug("Model Accuracy: {}".format(self.confidence))
9295
return self.confidence
9396

94-
async def predict(self, repos: AsyncIterator[Repo]) -> AsyncIterator[Repo]:
95-
if self.confidence is None:
96-
raise ValueError("Model Not Trained")
97+
async def predict(
98+
self, repos: AsyncIterator[Repo]
99+
) -> AsyncIterator[Tuple[Repo, Any, float]]:
100+
if not os.path.isfile(self._filename()):
101+
raise ModelNotTrained("Train model before prediction.")
97102
async for repo in repos:
98103
feature_data = repo.features(self.features)
99104
df = pd.DataFrame(feature_data, index=[0])

model/scratch/dffml_model_scratch/slr.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515
from dffml.source.source import Sources
1616
from dffml.feature import Features
1717
from dffml.accuracy import Accuracy
18-
from dffml.model.model import ModelConfig, ModelContext, Model
18+
from dffml.model.model import ModelConfig, ModelContext, Model, ModelNotTrained
1919
from dffml.util.entrypoint import entry_point
2020
from dffml.util.cli.arg import Arg
2121

@@ -110,11 +110,15 @@ async def train(self, sources: Sources):
110110

111111
async def accuracy(self, sources: Sources) -> Accuracy:
112112
if self.regression_line is None:
113-
raise ValueError("Model Not Trained")
113+
raise ModelNotTrained("Train model before assessing for accuracy.")
114114
accuracy_value = self.regression_line[2]
115115
return Accuracy(accuracy_value)
116116

117-
async def predict(self, repos: AsyncIterator[Repo]) -> AsyncIterator[Repo]:
117+
async def predict(
118+
self, repos: AsyncIterator[Repo]
119+
) -> AsyncIterator[Tuple[Repo, Any, float]]:
120+
if self.regression_line is None:
121+
raise ModelNotTrained("Train model before prediction.")
118122
async for repo in repos:
119123
feature_data = repo.features(self.features)
120124
repo.predicted(

model/tensorflow/dffml_model_tensorflow/dnnc.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@
1616
from dffml.repo import Repo
1717
from dffml.feature import Feature, Features
1818
from dffml.source.source import Sources
19-
from dffml.model.model import ModelConfig, ModelContext, Model
19+
from dffml.model.model import ModelConfig, ModelContext, Model, ModelNotTrained
2020
from dffml.accuracy import Accuracy
2121
from dffml.util.entrypoint import entry_point
2222
from dffml.base import BaseConfig
@@ -300,7 +300,7 @@ async def accuracy(self, sources: Sources) -> Accuracy:
300300
as test data.
301301
"""
302302
if not os.path.isdir(self.model_dir_path):
303-
raise NotADirectoryError("Model not trained")
303+
raise ModelNotTrained("Train model before assessing for accuracy.")
304304
input_fn = await self.accuracy_input_fn(
305305
sources, batch_size=20, shuffle=False, epochs=1
306306
)
@@ -312,7 +312,7 @@ async def predict(self, repos: AsyncIterator[Repo]) -> AsyncIterator[Repo]:
312312
Uses trained data to make a prediction about the quality of a repo.
313313
"""
314314
if not os.path.isdir(self.model_dir_path):
315-
raise NotADirectoryError("Model not trained")
315+
raise ModelNotTrained("Train model before prediction.")
316316
# Create the input function
317317
input_fn, predict = await self.predict_input_fn(repos)
318318
# Makes predictions on classifications

0 commit comments

Comments
 (0)