Skip to content
This repository was archived by the owner on Aug 25, 2024. It is now read-only.

Commit f5a0e14

Browse files
committed
model: Simplified Model API with SimpleModel
Signed-off-by: John Andersen <[email protected]>
1 parent 39f88dd commit f5a0e14

File tree

4 files changed

+123
-3
lines changed

4 files changed

+123
-3
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1313
- Doctestable Examples to high-level API.
1414
- Shouldi got an operation to run npm-audit on JavaScript code
1515
- Docstrings and doctestable examples for `record.py` (features and evaluated)
16+
- Simplified model API with SimpleModel
1617
### Changed
1718
- Restructured contributing documentation
1819
- Use randomly generated data for scikit tests

dffml/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from .source.json import JSONSource
1010

1111
# Models
12-
from .model import Model, ModelContext
12+
from .model import Model, ModelContext, SimpleModel, ModelNotTrained
1313

1414
# Utilities
1515
from .util.asynctestcase import AsyncTestCase

dffml/model/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,4 @@
1515
>>> },
1616
>>> )
1717
"""
18-
from .model import Model, ModelContext
18+
from .model import Model, ModelContext, SimpleModel, ModelNotTrained

dffml/model/model.py

Lines changed: 120 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77
"""
88
import os
99
import abc
10-
from typing import AsyncIterator
10+
import json
11+
import hashlib
12+
import pathlib
13+
from typing import AsyncIterator, Optional
1114

1215
from ..base import (
1316
config,
@@ -82,3 +85,119 @@ def __call__(self) -> ModelContext:
8285
if directory is not None and not os.path.isdir(directory):
8386
os.makedirs(directory)
8487
return self.CONTEXT(self)
88+
89+
90+
class SimpleModelNoContext:
91+
"""
92+
No need for CONTEXT since we implement __call__
93+
"""
94+
95+
96+
class SimpleModel(Model):
97+
DTYPES = [int, float]
98+
NUM_SUPPORTED_FEATURES = -1
99+
SUPPORTED_LENGTHS = None
100+
CONTEXT = SimpleModelNoContext
101+
102+
def __init__(self, config: "BaseConfig") -> None:
103+
super().__init__(config)
104+
self.storage = {}
105+
self.features = self.applicable_features(config.features)
106+
self._in_context = 0
107+
108+
def __call__(self):
109+
return self
110+
111+
async def __aenter__(self) -> Model:
112+
self._in_context += 1
113+
# If we've already entered the model's context once, don't reload
114+
if self._in_context > 1:
115+
return self
116+
self.open()
117+
return self
118+
119+
async def __aexit__(self, exc_type, exc_value, traceback):
120+
self._in_context -= 1
121+
if not self._in_context:
122+
self.close()
123+
124+
def open(self):
125+
"""
126+
Load saved model from disk if it exists.
127+
"""
128+
# Load saved data if this is the first time we've entred the model
129+
filepath = self.disk_path(extention=".json")
130+
if filepath.is_file():
131+
self.storage = json.loads(filepath.read_text())
132+
self.logger.debug("Loaded model from %s", filepath)
133+
else:
134+
self.logger.debug("No saved model in %s", filepath)
135+
136+
def close(self):
137+
"""
138+
Save model to disk.
139+
"""
140+
filepath = self.disk_path(extention=".json")
141+
filepath.write_text(json.dumps(self.storage))
142+
self.logger.debug("Saved model to %s", filepath)
143+
144+
def disk_path(self, extention: Optional[str] = None):
145+
"""
146+
We do this for convenience of the user so they can usually just use the
147+
default directory and if they train models with different parameters
148+
this method transparently to the user creates a filename unique the that
149+
configuration of the model where data is saved and loaded.
150+
"""
151+
# Export the config to a dictionary
152+
exported = self.config._asdict()
153+
# Remove the directory from the exported dict
154+
if "directory" in exported:
155+
del exported["directory"]
156+
# Replace features with the sorted list of features
157+
if "features" in exported:
158+
exported["features"] = dict(sorted(exported["features"].items()))
159+
# Hash the exported config
160+
return pathlib.Path(
161+
self.config.directory,
162+
hashlib.sha384(json.dumps(exported).encode()).hexdigest()
163+
+ (extention if extention else ""),
164+
)
165+
166+
def applicable_features(self, features):
167+
usable = []
168+
# Check that we aren't trying to use more features than the model
169+
# supports
170+
if (
171+
self.NUM_SUPPORTED_FEATURES != -1
172+
and len(features) != self.NUM_SUPPORTED_FEATURES
173+
):
174+
msg = f"{self.__class__.__qualname__} doesn't support more than "
175+
if self.NUM_SUPPORTED_FEATURES == 1:
176+
msg += f"{self.NUM_SUPPORTED_FEATURES} feature"
177+
else:
178+
msg += f"{self.NUM_SUPPORTED_FEATURES} features"
179+
raise ValueError(msg)
180+
# Check data type and length for each feature
181+
for feature in features:
182+
if self.check_applicable_feature(feature):
183+
usable.append(feature.NAME)
184+
# Return a sorted list of feature names for consistency. In case users
185+
# provide the same list of features to applicable_features in a
186+
# different order.
187+
return sorted(usable)
188+
189+
def check_applicable_feature(self, feature):
190+
# Check the data datatype is in the list of supported data types
191+
if feature.dtype() not in self.DTYPES:
192+
msg = f"{self.__class__.__qualname__} only supports features "
193+
msg += f"with these data types: {self.DTYPES}"
194+
raise ValueError(msg)
195+
# If SUPPORTED_LENGTHS is None then all lengths are supported
196+
if self.SUPPORTED_LENGTHS is None:
197+
return True
198+
# Check that length (dimensions) of feature is supported
199+
if feature.length() not in self.SUPPORTED_LENGTHS:
200+
msg = f"{self.__class__.__qualname__} only supports "
201+
msg += f"{self.SUPPORTED_LENGTHS} dimensional values"
202+
raise ValueError(msg)
203+
return True

0 commit comments

Comments
 (0)