Skip to content
This repository was archived by the owner on Aug 25, 2024. It is now read-only.

Commit 03cd4b5

Browse files
committed
service: http: Support for models
* Introduced context creation API as well. Signed-off-by: John Andersen <[email protected]>
1 parent 68792bf commit 03cd4b5

File tree

6 files changed

+1126
-107
lines changed

6 files changed

+1126
-107
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
3131
- Entrypoint listing command to development service to aid in debugging issues
3232
with entrypoints.
3333
- HTTP API service to enable interacting with DFFML over HTTP. Currently
34-
includes APIs for configuring and using Sources.
34+
includes APIs for configuring and using Sources and Models.
3535
- MySQL protocol source to work with data from a MySQL protocol compatible db
3636
- shouldi example got a bandit operation which tells users not to install if
3737
there are more than 5 issues of high severity and confidence.

dffml/feature/feature.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,17 @@ def __str__(self):
159159
def __repr__(self):
160160
return "%s[%r, %d]" % (self.__str__(), self.dtype(), self.length())
161161

162+
def export(self):
163+
return {
164+
"name": self.NAME,
165+
"dtype": self.dtype().__qualname__,
166+
"length": self.length(),
167+
}
168+
169+
@classmethod
170+
def _fromdict(cls, **kwargs):
171+
return cls.load_def(**kwargs)
172+
162173
def dtype(self) -> Type:
163174
"""
164175
Models need to know a Feature's datatype.
@@ -290,6 +301,20 @@ def __init__(self, *args: Feature, timeout: int = None) -> None:
290301
def names(self) -> List[str]:
291302
return list(({feature.NAME: True for feature in self}).keys())
292303

304+
def export(self):
305+
return {feature.NAME: feature.export() for feature in self}
306+
307+
@classmethod
308+
def _fromdict(cls, **kwargs):
309+
for name, feature_def in kwargs.items():
310+
feature_def.setdefault("name", name)
311+
return cls(
312+
*[
313+
Feature._fromdict(**feature_data)
314+
for feature_data in kwargs.values()
315+
]
316+
)
317+
293318
async def evaluate(self, src: str, task: Task = None) -> Dict[str, Any]:
294319
return await asyncio.wait_for(
295320
self._evaluate(src, task=task), self.timeout

service/http/dffml_service_http/routes.py

Lines changed: 211 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,16 @@
77
from functools import partial
88
from dataclasses import dataclass
99
from contextlib import AsyncExitStack
10-
from typing import List, Union, AsyncIterator
10+
from typing import List, Union, AsyncIterator, Dict
1111

1212
from aiohttp import web
1313
import aiohttp_cors
1414

1515
from dffml.repo import Repo
1616
from dffml.base import MissingConfig
17-
from dffml.source.source import BaseSource
17+
from dffml.model import Model
18+
from dffml.feature import Features
19+
from dffml.source.source import BaseSource, SourcesContext
1820
from dffml.util.entrypoint import EntrypointNotFound
1921

2022

@@ -26,6 +28,8 @@
2628

2729
OK = {"error": None}
2830
SOURCE_NOT_LOADED = {"error": "Source not loaded"}
31+
MODEL_NOT_LOADED = {"error": "Model not loaded"}
32+
MODEL_NO_SOURCES = {"error": "No source context labels given"}
2933

3034

3135
class JSONEncoder(json.JSONEncoder):
@@ -54,7 +58,7 @@ class IterkeyEntry:
5458

5559
def sctx_route(handler):
5660
"""
57-
Ensure that the labeled sctx requested is loaded. Return the sctx
61+
Ensure that the labeled source context requested is loaded. Return the sctx
5862
if it is loaded and an error otherwise.
5963
"""
6064

@@ -72,15 +76,36 @@ async def get_sctx(self, request):
7276
return get_sctx
7377

7478

79+
def mctx_route(handler):
80+
"""
81+
Ensure that the labeled model context requested is loaded. Return the mctx
82+
if it is loaded and an error otherwise.
83+
"""
84+
85+
@wraps(handler)
86+
async def get_mctx(self, request):
87+
mctx = request.app["model_contexts"].get(
88+
request.match_info["label"], None
89+
)
90+
if mctx is None:
91+
return web.json_response(
92+
MODEL_NOT_LOADED, status=HTTPStatus.NOT_FOUND
93+
)
94+
return await handler(self, request, mctx)
95+
96+
return get_mctx
97+
98+
7599
class Routes:
76100
@web.middleware
77101
async def error_middleware(self, request, handler):
78102
try:
79103
return await handler(request)
80104
except web.HTTPException as error:
81-
return web.json_response(
82-
{"error": error.reason}, status=error.status
83-
)
105+
response = {"error": error.reason}
106+
if error.text is not None:
107+
response["error"] = error.text
108+
return web.json_response(response, status=error.status)
84109
except Exception as error: # pragma: no cov
85110
self.logger.error(
86111
"ERROR handling %s: %s",
@@ -160,6 +185,9 @@ async def configure_source(self, request):
160185
try:
161186
source = source.withconfig(config)
162187
except MissingConfig as error:
188+
self.logger.error(
189+
f"failed to configure source {source_name}: {error}"
190+
)
163191
return web.json_response(
164192
{"error": str(error)}, status=HTTPStatus.BAD_REQUEST
165193
)
@@ -168,15 +196,102 @@ async def configure_source(self, request):
168196
exit_stack = request.app["exit_stack"]
169197
source = await exit_stack.enter_async_context(source)
170198
request.app["sources"][label] = source
171-
sctx = await exit_stack.enter_async_context(source())
172-
request.app["source_contexts"][label] = sctx
199+
200+
return web.json_response(OK)
201+
202+
async def context_source(self, request):
203+
label = request.match_info["label"]
204+
ctx_label = request.match_info["ctx_label"]
205+
206+
if not label in request.app["sources"]:
207+
return web.json_response(
208+
{"error": f"{label} source not found"},
209+
status=HTTPStatus.NOT_FOUND,
210+
)
211+
212+
# Enter the source context and pass the features
213+
exit_stack = request.app["exit_stack"]
214+
source = request.app["sources"][label]
215+
mctx = await exit_stack.enter_async_context(source())
216+
request.app["source_contexts"][ctx_label] = mctx
217+
218+
return web.json_response(OK)
219+
220+
async def list_models(self, request):
221+
return web.json_response(
222+
{
223+
model.ENTRY_POINT_ORIG_LABEL: model.args({})
224+
for model in Model.load()
225+
},
226+
dumps=partial(json.dumps, cls=JSONEncoder),
227+
)
228+
229+
async def configure_model(self, request):
230+
model_name = request.match_info["model"]
231+
label = request.match_info["label"]
232+
233+
config = await request.json()
234+
235+
try:
236+
model = Model.load_labeled(f"{label}={model_name}")
237+
except EntrypointNotFound as error:
238+
self.logger.error(
239+
f"/configure/model/ failed to load model: {error}"
240+
)
241+
return web.json_response(
242+
{"error": f"model {model_name} not found"},
243+
status=HTTPStatus.NOT_FOUND,
244+
)
245+
246+
try:
247+
model = model.withconfig(config)
248+
except MissingConfig as error:
249+
self.logger.error(
250+
f"failed to configure model {model_name}: {error}"
251+
)
252+
return web.json_response(
253+
{"error": str(error)}, status=HTTPStatus.BAD_REQUEST
254+
)
255+
256+
# DFFML objects all follow a double context entry pattern
257+
exit_stack = request.app["exit_stack"]
258+
model = await exit_stack.enter_async_context(model)
259+
request.app["models"][label] = model
260+
261+
return web.json_response(OK)
262+
263+
async def context_model(self, request):
264+
label = request.match_info["label"]
265+
ctx_label = request.match_info["ctx_label"]
266+
267+
if not label in request.app["models"]:
268+
return web.json_response(
269+
{"error": f"{label} model not found"},
270+
status=HTTPStatus.NOT_FOUND,
271+
)
272+
273+
features_dict = await request.json()
274+
275+
try:
276+
features = Features._fromdict(**features_dict)
277+
except:
278+
return web.json_response(
279+
{"error": "Incorrect format for features"},
280+
status=HTTPStatus.BAD_REQUEST,
281+
)
282+
283+
# Enter the model context and pass the features
284+
exit_stack = request.app["exit_stack"]
285+
model = request.app["models"][label]
286+
mctx = await exit_stack.enter_async_context(model(features))
287+
request.app["model_contexts"][ctx_label] = mctx
173288

174289
return web.json_response(OK)
175290

176291
@sctx_route
177292
async def source_repo(self, request, sctx):
178293
return web.json_response(
179-
(await sctx.repo(request.match_info["key"])).dict()
294+
(await sctx.repo(request.match_info["key"])).export()
180295
)
181296

182297
@sctx_route
@@ -232,7 +347,7 @@ async def source_repos(self, request, sctx):
232347
return web.json_response(
233348
{
234349
"iterkey": iterkey,
235-
"repos": {repo.src_url: repo.dict() for repo in repos},
350+
"repos": {repo.src_url: repo.export() for repo in repos},
236351
}
237352
)
238353

@@ -245,7 +360,73 @@ async def source_repos_iter(self, request, sctx):
245360
return web.json_response(
246361
{
247362
"iterkey": iterkey,
248-
"repos": {repo.src_url: repo.dict() for repo in repos},
363+
"repos": {repo.src_url: repo.export() for repo in repos},
364+
}
365+
)
366+
367+
async def get_source_contexts(self, request, sctx_label_list):
368+
sources_context = SourcesContext([])
369+
for label in sctx_label_list:
370+
sctx = request.app["source_contexts"].get(label, None)
371+
if sctx is None:
372+
raise web.HTTPNotFound(
373+
text=list(SOURCE_NOT_LOADED.values())[0],
374+
content_type="application/json",
375+
)
376+
sources_context.append(sctx)
377+
if not sources_context:
378+
raise web.HTTPBadRequest(
379+
text=list(MODEL_NO_SOURCES.values())[0],
380+
content_type="application/json",
381+
)
382+
return sources_context
383+
384+
@mctx_route
385+
async def model_train(self, request, mctx):
386+
# Get the list of source context labels to pass to mctx.train
387+
sctx_label_list = await request.json()
388+
# Get all the source contexts
389+
sources = await self.get_source_contexts(request, sctx_label_list)
390+
# Train the model on the sources
391+
await mctx.train(sources)
392+
return web.json_response(OK)
393+
394+
@mctx_route
395+
async def model_accuracy(self, request, mctx):
396+
# Get the list of source context labels to pass to mctx.train
397+
sctx_label_list = await request.json()
398+
# Get all the source contexts
399+
sources = await self.get_source_contexts(request, sctx_label_list)
400+
# Train the model on the sources
401+
return web.json_response({"accuracy": await mctx.accuracy(sources)})
402+
403+
@mctx_route
404+
async def model_predict(self, request, mctx):
405+
# TODO Provide an iterkey method for model prediction
406+
chunk_size = int(request.match_info["chunk_size"])
407+
if chunk_size != 0:
408+
return web.json_response(
409+
{"error": "Multiple request iteration not yet supported"},
410+
status=HTTPStatus.BAD_REQUEST,
411+
)
412+
# Get the repos
413+
repos: Dict[str, Repo] = {
414+
src_url: Repo(src_url, data=repo_data)
415+
for src_url, repo_data in (await request.json()).items()
416+
}
417+
# Create an async generator to feed repos
418+
async def repo_gen():
419+
for repo in repos.values():
420+
yield repo
421+
422+
# Feed them through prediction
423+
return web.json_response(
424+
{
425+
"iterkey": None,
426+
"repos": {
427+
repo.src_url: repo.export()
428+
async for repo in mctx.predict(repo_gen())
429+
},
249430
}
250431
)
251432

@@ -272,6 +453,8 @@ async def setup(self, **kwargs):
272453
self.app["sources"] = {}
273454
self.app["source_contexts"] = {}
274455
self.app["source_repos_iterkeys"] = {}
456+
self.app["models"] = {}
457+
self.app["model_contexts"] = {}
275458
self.app.update(kwargs)
276459
self.routes = [
277460
# HTTP Service specific APIs
@@ -283,6 +466,14 @@ async def setup(self, **kwargs):
283466
"/configure/source/{source}/{label}",
284467
self.configure_source,
285468
),
469+
(
470+
"GET",
471+
"/context/source/{label}/{ctx_label}",
472+
self.context_source,
473+
),
474+
("GET", "/list/models", self.list_models),
475+
("POST", "/configure/model/{model}/{label}", self.configure_model),
476+
("POST", "/context/model/{label}/{ctx_label}", self.context_model),
286477
# Source APIs
287478
("GET", "/source/{label}/repo/{key}", self.source_repo),
288479
("POST", "/source/{label}/update/{key}", self.source_update),
@@ -293,6 +484,15 @@ async def setup(self, **kwargs):
293484
self.source_repos_iter,
294485
),
295486
# TODO route to delete iterkey before iteration has completed
487+
# Model APIs
488+
("POST", "/model/{label}/train", self.model_train),
489+
("POST", "/model/{label}/accuracy", self.model_accuracy),
490+
# TODO Provide an iterkey method for model prediction
491+
(
492+
"POST",
493+
"/model/{label}/predict/{chunk_size}",
494+
self.model_predict,
495+
),
296496
]
297497
for route in self.routes:
298498
route = self.app.router.add_route(*route)

0 commit comments

Comments
 (0)