Skip to content
This repository was archived by the owner on Aug 25, 2024. It is now read-only.

Commit 2b35bdd

Browse files
committed
high level: Add very abstract Python APIs
Fixes: #381 Fixes: #307 Fixes: #287 Signed-off-by: John Andersen <[email protected]>
1 parent d638e59 commit 2b35bdd

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+784
-182
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
5757
- Subclasses of `BaseConfigurable` will now auto instantiate their respective
5858
config classes using `kwargs` if the config argument isn't given and keyword
5959
arguments are.
60+
- The quickstart documentation was improved as well as the structure of docs.
6061
### Fixed
6162
- CONTRIBUTING.md has `-e` in the wrong place in the getting setup section.
6263
- Since moving to auto `args()` and `config()`, BaseConfigurable no longer

dffml/__init__.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
1-
# SPDX-License-Identifier: MIT
2-
# Copyright (c) 2019 Intel Corporation
3-
from .feature import Feature
1+
# General
2+
from .high_level import train, accuracy, predict
3+
from .feature import Features, Feature, DefFeature
4+
5+
# Sources
6+
from .source.source import Sources, BaseSource, BaseSourceContext
7+
from .source.csv import CSVSource
8+
from .source.json import JSONSource
9+
10+
# Models
11+
from .model import Model, ModelContext
412

513
# Used to declare our namespace for resource discovery
614
__import__("pkg_resources").declare_namespace(__name__)

dffml/cli/ml.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from ..source.source import SubsetSources
22
from ..util.cli.arg import Arg
33
from ..util.cli.cmd import CMD
4+
from ..high_level import train, predict, accuracy
45
from ..util.cli.cmds import SourcesCMD, ModelCMD, KeysCMD
56

67

@@ -22,18 +23,14 @@ class Train(MLCMD):
2223
"""
2324

2425
async def run(self):
25-
async with self.sources as sources, self.model as model:
26-
async with sources() as sctx, model() as mctx:
27-
return await mctx.train(sctx)
26+
return await train(self.model, self.sources)
2827

2928

3029
class Accuracy(MLCMD):
3130
"""Assess model accuracy on data from given sources"""
3231

3332
async def run(self):
34-
async with self.sources as sources, self.model as model:
35-
async with sources() as sctx, model() as mctx:
36-
return float(await mctx.accuracy(sctx))
33+
return await accuracy(self.model, self.sources)
3734

3835

3936
class PredictAll(MLCMD):
@@ -47,17 +44,11 @@ class PredictAll(MLCMD):
4744
action="store_true",
4845
)
4946

50-
async def predict(self, mctx, sctx, repos):
51-
async for repo in mctx.predict(repos):
52-
yield repo
53-
if self.update:
54-
await sctx.update(repo)
55-
5647
async def run(self):
57-
async with self.sources as sources, self.model as model:
58-
async with sources() as sctx, model() as mctx:
59-
async for repo in self.predict(mctx, sctx, sctx.repos()):
60-
yield repo
48+
async for repo in predict(
49+
self.model, self.sources, update=self.update, keep_repo=True
50+
):
51+
yield repo
6152

6253

6354
class PredictRepo(PredictAll, KeysCMD):

dffml/high_level.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
"""
2+
High level abstraction interfaces to DFFML. These are probably going to be used
3+
in a lot of quick and dirty Python files.
4+
"""
5+
import pathlib
6+
from typing import Union, Dict, Any
7+
8+
from .repo import Repo
9+
from .source.source import Sources, BaseSource
10+
from .source.memory import MemorySource, MemorySourceConfig
11+
12+
13+
def _repos_to_sources(*args):
    """
    Collect every positional argument into a single ``Sources`` list.

    Each argument is handled according to its type:

    - ``Sources`` given as the *first* argument: used as the list everything
      else is appended to. It is modified in place and returned.
    - ``BaseSource``: added to the list as is.
    - ``dict``: wrapped in a ``Repo`` whose key is the argument's position and
      whose data is ``{"features": <the dict>}``.
    - ``Repo``: gathered, together with the converted dicts, into one
      ``MemorySource`` which is appended after everything else.
    - ``str`` containing a ``"."``: treated as a filename. Its extension
      (``pathlib.Path(arg).suffix`` without the dot) is handed to
      ``BaseSource.load`` and the loaded source is instantiated with
      ``filename=<the string>``.

    No other argument types are handled by this function.
    """
    if args and isinstance(args[0], Sources):
        # NOTE: the caller's Sources object is extended in place.
        sources = args[0]
        # Sources passed after a leading Sources object used to be silently
        # dropped on this path, append them like the docstring promises.
        for arg in args[1:]:
            if isinstance(arg, BaseSource):
                sources.append(arg)
    else:
        sources = Sources(
            *[arg for arg in args if isinstance(arg, BaseSource)]
        )
    # Repos to add to the memory source
    repos = []
    for i, arg in enumerate(args):
        # Convert dicts to repos, keyed by their position in args
        if isinstance(arg, dict):
            arg = Repo(i, data={"features": arg})
        if isinstance(arg, Repo):
            repos.append(arg)
        elif isinstance(arg, str) and "." in arg:
            # Pick the source implementation from the file extension
            filepath = pathlib.Path(arg)
            source = BaseSource.load(filepath.suffix.replace(".", ""))
            sources.append(source(filename=arg))
    # Create memory source if there are any repos
    if repos:
        sources.append(MemorySource(MemorySourceConfig(repos=repos)))
    return sources
44+
45+
46+
async def train(model, *args: Union[BaseSource, Repo, Dict[str, Any]]):
    """
    Train a machine learning model on the given data.

    The model should already be instantiated. All data arguments are combined
    into one set of sources by ``_repos_to_sources`` before training starts.

    Parameters
    ----------
    model : Model
        Machine Learning model to use. See :doc:`/plugins/dffml_model` for
        models options.
    *args : list
        Input data for training. Each item can be a ``dict`` of features, a
        :py:class:`Repo`, one of the data :doc:`/plugins/dffml_source`, or a
        filename whose extension names one of the data sources.

    Returns
    -------
    The value returned by the model context's ``train`` coroutine.
    """
    training_data = _repos_to_sources(*args)
    # Enter the sources and the model first, then their contexts
    async with training_data as opened_sources, model as opened_model:
        async with opened_sources() as sctx, opened_model() as mctx:
            return await mctx.train(sctx)
67+
68+
69+
async def accuracy(
    model, *args: Union[BaseSource, Repo, Dict[str, Any]]
) -> float:
    """
    Assess the accuracy of a machine learning model.

    The given records are handed to the model so it can report how well it
    predicts them. The model should already be instantiated and trained.

    Parameters
    ----------
    model : Model
        Machine Learning model to use. See :doc:`/plugins/dffml_model` for
        models options.
    *args : list
        Input data to assess accuracy with. Each item can be a ``dict`` of
        features, a :py:class:`Repo`, one of the data
        :doc:`/plugins/dffml_source`, or a filename whose extension names one
        of the data sources.

    Returns
    -------
    float
        A decimal value representing the percent of the time the model made the
        correct prediction. For some models this has another meaning. Please see
        the documentation for the model you're using for further details.
    """
    assessment_data = _repos_to_sources(*args)
    # Same entry order as a single two-item ``async with``: sources, then model
    async with assessment_data as opened_sources:
        async with model as opened_model:
            async with opened_sources() as sctx, opened_model() as mctx:
                score = await mctx.accuracy(sctx)
                # Whatever the model reports is converted to a plain float
                return float(score)
99+
100+
101+
async def predict(
    model,
    *args: Union[BaseSource, Repo, Dict[str, Any]],
    update: bool = False,
    keep_repo: bool = False,
):
    """
    Make predictions using a machine learning model.

    This is an asynchronous generator, iterate over it with ``async for``.
    The model must be trained before using it to make a prediction.

    Parameters
    ----------
    model : Model
        Machine Learning model to use. See :doc:`/plugins/dffml_model` for
        models options.
    *args : list
        Input data for prediction. Could be a ``dict``, :py:class:`Repo`,
        filename, or one of the data :doc:`/plugins/dffml_source`.
    update : boolean, optional
        If ``True`` each repo is passed to the sources context's ``update``
        so prediction data is written back. The write happens after the repo
        has been yielded, when the caller resumes the generator. Defaults to
        ``False``.
    keep_repo : boolean, optional
        If ``True`` the results will be kept as their ``Repo`` objects instead
        of being converted to a ``(repo.key, features, predictions)`` tuple.
        Defaults to ``False``.

    Yields
    ------
    ``Repo`` objects or ``(repo.key, features, predictions)`` tuples.
    """
    prediction_data = _repos_to_sources(*args)
    async with prediction_data as opened_sources, model as opened_model:
        async with opened_sources() as sctx, opened_model() as mctx:
            async for repo in mctx.predict(sctx.repos()):
                if keep_repo:
                    yield repo
                else:
                    yield (repo.key, repo.features(), repo.predictions())
                # Only reached once the consumer asks for the next result
                if update:
                    await sctx.update(repo)

dffml/model/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
>>> },
1616
>>> )
1717
"""
18-
from .model import Model
18+
from .model import Model, ModelContext
1919

2020
# Declares dffml.model as a namespace package
2121
__import__("pkg_resources").declare_namespace(__name__)

dffml/model/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from ..repo import Repo
1818
from ..source.source import Sources
1919
from ..feature import Features
20-
from ..accuracy import Accuracy
20+
from .accuracy import Accuracy
2121
from ..util.entrypoint import base_entry_point
2222

2323

dffml/noasync.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import asyncio
2+
3+
from . import high_level
4+
5+
6+
def train(*args, **kwargs):
    """
    Blocking counterpart of ``high_level.train``.

    All arguments are forwarded unchanged. The coroutine is run to completion
    with ``asyncio.run`` and its result is returned.
    """
    coroutine = high_level.train(*args, **kwargs)
    return asyncio.run(coroutine)
8+
9+
10+
def accuracy(*args, **kwargs):
    """
    Blocking counterpart of ``high_level.accuracy``.

    All arguments are forwarded unchanged. The coroutine is run to completion
    with ``asyncio.run`` and its result is returned.
    """
    coroutine = high_level.accuracy(*args, **kwargs)
    return asyncio.run(coroutine)
12+
13+
14+
def predict(*args, **kwargs):
    """
    Blocking counterpart of ``high_level.predict``.

    All arguments are forwarded unchanged. This is a regular generator: each
    ``next()`` drives the underlying asynchronous generator one step on a
    private event loop and yields the value it produced.

    The event loop is created when iteration starts. It is shut down and
    closed when the results are exhausted, when an exception propagates, or
    when this generator is closed early.
    """
    async_gen = high_level.predict(*args, **kwargs).__aiter__()

    loop = asyncio.new_event_loop()

    try:
        while True:
            try:
                result = loop.run_until_complete(async_gen.__anext__())
            except StopAsyncIteration:
                # Underlying async generator is exhausted
                return
            yield result
    finally:
        # A single cleanup path replaces the bare ``except:`` that duplicated
        # it. The loop is closed even if finalizing the async generators fails.
        try:
            loop.run_until_complete(loop.shutdown_asyncgens())
        finally:
            loop.close()

dffml/skel/model/REPLACE_IMPORT_PACKAGE_NAME/misc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from dffml.repo import Repo
99
from dffml.source.source import Sources
1010
from dffml.feature import Features
11-
from dffml.accuracy import Accuracy
11+
from dffml.model.accuracy import Accuracy
1212
from dffml.model.model import ModelContext, Model
1313
from dffml.util.entrypoint import entrypoint
1414
from dffml.base import config

dffml/source/file.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,9 +112,7 @@ def zip_opener_helper(self):
112112
@contextmanager
113113
def zip_closer_helper(self):
114114
with zipfile.ZipFile(
115-
self.config.filename,
116-
self.WRITEMODE,
117-
compression=zipfile.ZIP_BZIP2,
115+
self.config.filename, self.WRITEMODE, compression=zipfile.ZIP_BZIP2
118116
) as archive:
119117
with archive.open(
120118
self.__class__.__qualname__,

0 commit comments

Comments
 (0)