77from functools import partial
88from dataclasses import dataclass
99from contextlib import AsyncExitStack
10- from typing import List , Union , AsyncIterator
10+ from typing import List , Union , AsyncIterator , Dict
1111
1212from aiohttp import web
1313import aiohttp_cors
1414
1515from dffml .repo import Repo
1616from dffml .base import MissingConfig
17- from dffml .source .source import BaseSource
17+ from dffml .model import Model
18+ from dffml .feature import Features
19+ from dffml .source .source import BaseSource , SourcesContext
1820from dffml .util .entrypoint import EntrypointNotFound
1921
2022
2628
2729OK = {"error" : None }
2830SOURCE_NOT_LOADED = {"error" : "Source not loaded" }
31+ MODEL_NOT_LOADED = {"error" : "Model not loaded" }
32+ MODEL_NO_SOURCES = {"error" : "No source context labels given" }
2933
3034
3135class JSONEncoder (json .JSONEncoder ):
@@ -54,7 +58,7 @@ class IterkeyEntry:
5458
5559def sctx_route (handler ):
5660 """
57- Ensure that the labeled sctx requested is loaded. Return the sctx
61+ Ensure that the labeled source context requested is loaded. Return the sctx
5862 if it is loaded and an error otherwise.
5963 """
6064
@@ -72,15 +76,36 @@ async def get_sctx(self, request):
7276 return get_sctx
7377
7478
79+ def mctx_route (handler ):
80+ """
81+ Ensure that the labeled model context requested is loaded. Return the mctx
82+ if it is loaded and an error otherwise.
83+ """
84+
85+ @wraps (handler )
86+ async def get_mctx (self , request ):
87+ mctx = request .app ["model_contexts" ].get (
88+ request .match_info ["label" ], None
89+ )
90+ if mctx is None :
91+ return web .json_response (
92+ MODEL_NOT_LOADED , status = HTTPStatus .NOT_FOUND
93+ )
94+ return await handler (self , request , mctx )
95+
96+ return get_mctx
97+
98+
7599class Routes :
76100 @web .middleware
77101 async def error_middleware (self , request , handler ):
78102 try :
79103 return await handler (request )
80104 except web .HTTPException as error :
81- return web .json_response (
82- {"error" : error .reason }, status = error .status
83- )
105+ response = {"error" : error .reason }
106+ if error .text is not None :
107+ response ["error" ] = error .text
108+ return web .json_response (response , status = error .status )
84109 except Exception as error : # pragma: no cov
85110 self .logger .error (
86111 "ERROR handling %s: %s" ,
@@ -160,6 +185,9 @@ async def configure_source(self, request):
160185 try :
161186 source = source .withconfig (config )
162187 except MissingConfig as error :
188+ self .logger .error (
189+ f"failed to configure source { source_name } : { error } "
190+ )
163191 return web .json_response (
164192 {"error" : str (error )}, status = HTTPStatus .BAD_REQUEST
165193 )
@@ -168,15 +196,102 @@ async def configure_source(self, request):
168196 exit_stack = request .app ["exit_stack" ]
169197 source = await exit_stack .enter_async_context (source )
170198 request .app ["sources" ][label ] = source
171- sctx = await exit_stack .enter_async_context (source ())
172- request .app ["source_contexts" ][label ] = sctx
199+
200+ return web .json_response (OK )
201+
202+ async def context_source (self , request ):
203+ label = request .match_info ["label" ]
204+ ctx_label = request .match_info ["ctx_label" ]
205+
206+ if not label in request .app ["sources" ]:
207+ return web .json_response (
208+ {"error" : f"{ label } source not found" },
209+ status = HTTPStatus .NOT_FOUND ,
210+ )
211+
212+ # Enter the source context and pass the features
213+ exit_stack = request .app ["exit_stack" ]
214+ source = request .app ["sources" ][label ]
215+ mctx = await exit_stack .enter_async_context (source ())
216+ request .app ["source_contexts" ][ctx_label ] = mctx
217+
218+ return web .json_response (OK )
219+
220+ async def list_models (self , request ):
221+ return web .json_response (
222+ {
223+ model .ENTRY_POINT_ORIG_LABEL : model .args ({})
224+ for model in Model .load ()
225+ },
226+ dumps = partial (json .dumps , cls = JSONEncoder ),
227+ )
228+
229+ async def configure_model (self , request ):
230+ model_name = request .match_info ["model" ]
231+ label = request .match_info ["label" ]
232+
233+ config = await request .json ()
234+
235+ try :
236+ model = Model .load_labeled (f"{ label } ={ model_name } " )
237+ except EntrypointNotFound as error :
238+ self .logger .error (
239+ f"/configure/model/ failed to load model: { error } "
240+ )
241+ return web .json_response (
242+ {"error" : f"model { model_name } not found" },
243+ status = HTTPStatus .NOT_FOUND ,
244+ )
245+
246+ try :
247+ model = model .withconfig (config )
248+ except MissingConfig as error :
249+ self .logger .error (
250+ f"failed to configure model { model_name } : { error } "
251+ )
252+ return web .json_response (
253+ {"error" : str (error )}, status = HTTPStatus .BAD_REQUEST
254+ )
255+
256+ # DFFML objects all follow a double context entry pattern
257+ exit_stack = request .app ["exit_stack" ]
258+ model = await exit_stack .enter_async_context (model )
259+ request .app ["models" ][label ] = model
260+
261+ return web .json_response (OK )
262+
263+ async def context_model (self , request ):
264+ label = request .match_info ["label" ]
265+ ctx_label = request .match_info ["ctx_label" ]
266+
267+ if not label in request .app ["models" ]:
268+ return web .json_response (
269+ {"error" : f"{ label } model not found" },
270+ status = HTTPStatus .NOT_FOUND ,
271+ )
272+
273+ features_dict = await request .json ()
274+
275+ try :
276+ features = Features ._fromdict (** features_dict )
277+ except :
278+ return web .json_response (
279+ {"error" : "Incorrect format for features" },
280+ status = HTTPStatus .BAD_REQUEST ,
281+ )
282+
283+ # Enter the model context and pass the features
284+ exit_stack = request .app ["exit_stack" ]
285+ model = request .app ["models" ][label ]
286+ mctx = await exit_stack .enter_async_context (model (features ))
287+ request .app ["model_contexts" ][ctx_label ] = mctx
173288
174289 return web .json_response (OK )
175290
176291 @sctx_route
177292 async def source_repo (self , request , sctx ):
178293 return web .json_response (
179- (await sctx .repo (request .match_info ["key" ])).dict ()
294+ (await sctx .repo (request .match_info ["key" ])).export ()
180295 )
181296
182297 @sctx_route
@@ -232,7 +347,7 @@ async def source_repos(self, request, sctx):
232347 return web .json_response (
233348 {
234349 "iterkey" : iterkey ,
235- "repos" : {repo .src_url : repo .dict () for repo in repos },
350+ "repos" : {repo .src_url : repo .export () for repo in repos },
236351 }
237352 )
238353
@@ -245,7 +360,73 @@ async def source_repos_iter(self, request, sctx):
245360 return web .json_response (
246361 {
247362 "iterkey" : iterkey ,
248- "repos" : {repo .src_url : repo .dict () for repo in repos },
363+ "repos" : {repo .src_url : repo .export () for repo in repos },
364+ }
365+ )
366+
367+ async def get_source_contexts (self , request , sctx_label_list ):
368+ sources_context = SourcesContext ([])
369+ for label in sctx_label_list :
370+ sctx = request .app ["source_contexts" ].get (label , None )
371+ if sctx is None :
372+ raise web .HTTPNotFound (
373+ text = list (SOURCE_NOT_LOADED .values ())[0 ],
374+ content_type = "application/json" ,
375+ )
376+ sources_context .append (sctx )
377+ if not sources_context :
378+ raise web .HTTPBadRequest (
379+ text = list (MODEL_NO_SOURCES .values ())[0 ],
380+ content_type = "application/json" ,
381+ )
382+ return sources_context
383+
384+ @mctx_route
385+ async def model_train (self , request , mctx ):
386+ # Get the list of source context labels to pass to mctx.train
387+ sctx_label_list = await request .json ()
388+ # Get all the source contexts
389+ sources = await self .get_source_contexts (request , sctx_label_list )
390+ # Train the model on the sources
391+ await mctx .train (sources )
392+ return web .json_response (OK )
393+
394+ @mctx_route
395+ async def model_accuracy (self , request , mctx ):
396+ # Get the list of source context labels to pass to mctx.train
397+ sctx_label_list = await request .json ()
398+ # Get all the source contexts
399+ sources = await self .get_source_contexts (request , sctx_label_list )
400+ # Train the model on the sources
401+ return web .json_response ({"accuracy" : await mctx .accuracy (sources )})
402+
403+ @mctx_route
404+ async def model_predict (self , request , mctx ):
405+ # TODO Provide an iterkey method for model prediction
406+ chunk_size = int (request .match_info ["chunk_size" ])
407+ if chunk_size != 0 :
408+ return web .json_response (
409+ {"error" : "Multiple request iteration not yet supported" },
410+ status = HTTPStatus .BAD_REQUEST ,
411+ )
412+ # Get the repos
413+ repos : Dict [str , Repo ] = {
414+ src_url : Repo (src_url , data = repo_data )
415+ for src_url , repo_data in (await request .json ()).items ()
416+ }
417+ # Create an async generator to feed repos
418+ async def repo_gen ():
419+ for repo in repos .values ():
420+ yield repo
421+
422+ # Feed them through prediction
423+ return web .json_response (
424+ {
425+ "iterkey" : None ,
426+ "repos" : {
427+ repo .src_url : repo .export ()
428+ async for repo in mctx .predict (repo_gen ())
429+ },
249430 }
250431 )
251432
@@ -272,6 +453,8 @@ async def setup(self, **kwargs):
272453 self .app ["sources" ] = {}
273454 self .app ["source_contexts" ] = {}
274455 self .app ["source_repos_iterkeys" ] = {}
456+ self .app ["models" ] = {}
457+ self .app ["model_contexts" ] = {}
275458 self .app .update (kwargs )
276459 self .routes = [
277460 # HTTP Service specific APIs
@@ -283,6 +466,14 @@ async def setup(self, **kwargs):
283466 "/configure/source/{source}/{label}" ,
284467 self .configure_source ,
285468 ),
469+ (
470+ "GET" ,
471+ "/context/source/{label}/{ctx_label}" ,
472+ self .context_source ,
473+ ),
474+ ("GET" , "/list/models" , self .list_models ),
475+ ("POST" , "/configure/model/{model}/{label}" , self .configure_model ),
476+ ("POST" , "/context/model/{label}/{ctx_label}" , self .context_model ),
286477 # Source APIs
287478 ("GET" , "/source/{label}/repo/{key}" , self .source_repo ),
288479 ("POST" , "/source/{label}/update/{key}" , self .source_update ),
@@ -293,6 +484,15 @@ async def setup(self, **kwargs):
293484 self .source_repos_iter ,
294485 ),
295486 # TODO route to delete iterkey before iteration has completed
487+ # Model APIs
488+ ("POST" , "/model/{label}/train" , self .model_train ),
489+ ("POST" , "/model/{label}/accuracy" , self .model_accuracy ),
490+ # TODO Provide an iterkey method for model prediction
491+ (
492+ "POST" ,
493+ "/model/{label}/predict/{chunk_size}" ,
494+ self .model_predict ,
495+ ),
296496 ]
297497 for route in self .routes :
298498 route = self .app .router .add_route (* route )
0 commit comments