Skip to content

Commit 575b130

Browse files
committed
create FunctionJobTaskClientService
1 parent 718f987 commit 575b130

File tree

2 files changed

+416
-2
lines changed

2 files changed

+416
-2
lines changed
Lines changed: 380 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,380 @@
1+
from dataclasses import dataclass
2+
from typing import Final
3+
4+
from common_library.exclude import as_dict_exclude_none
5+
from models_library.functions import (
6+
FunctionClass,
7+
FunctionID,
8+
FunctionInputs,
9+
FunctionJobCollectionID,
10+
FunctionJobID,
11+
FunctionJobStatus,
12+
FunctionOutputs,
13+
ProjectFunctionJob,
14+
RegisteredFunction,
15+
RegisteredFunctionJob,
16+
RegisteredFunctionJobWithStatus,
17+
SolverFunctionJob,
18+
TaskID,
19+
)
20+
from models_library.functions_errors import (
21+
FunctionExecuteAccessDeniedError,
22+
FunctionInputsValidationError,
23+
FunctionsExecuteApiAccessDeniedError,
24+
UnsupportedFunctionClassError,
25+
UnsupportedFunctionFunctionJobClassCombinationError,
26+
)
27+
from models_library.products import ProductName
28+
from models_library.projects_state import RunningState
29+
from models_library.rest_pagination import PageMetaInfoLimitOffset, PageOffsetInt
30+
from models_library.rpc_pagination import PageLimitInt
31+
from models_library.users import UserID
32+
from servicelib.celery.models import TaskUUID
33+
from servicelib.celery.task_manager import TaskManager
34+
from sqlalchemy.ext.asyncio import AsyncEngine
35+
36+
from ._service_function_jobs import FunctionJobService
37+
from ._service_functions import FunctionService
38+
from ._service_jobs import JobService
39+
from .api.routes.tasks import _get_task_filter
40+
from .exceptions.function_errors import (
41+
FunctionJobCacheNotFoundError,
42+
)
43+
from .models.domain.functions import PreRegisteredFunctionJobData
44+
from .models.schemas.jobs import JobInputs
45+
from .services_http.webserver import AuthSession
46+
from .services_rpc.storage import StorageService
47+
from .services_rpc.wb_api_server import WbApiRpcClient
48+
49+
_JOB_CREATION_TASK_STATUS_PREFIX: Final[str] = "JOB_CREATION_TASK_STATUS_"
50+
_JOB_CREATION_TASK_NOT_YET_SCHEDULED_STATUS: Final[str] = (
51+
f"{_JOB_CREATION_TASK_STATUS_PREFIX}NOT_YET_SCHEDULED"
52+
)
53+
54+
55+
def join_inputs(
56+
default_inputs: FunctionInputs | None,
57+
function_inputs: FunctionInputs | None,
58+
) -> FunctionInputs:
59+
if default_inputs is None:
60+
return function_inputs
61+
62+
if function_inputs is None:
63+
return default_inputs
64+
65+
# last dict will override defaults
66+
return {**default_inputs, **function_inputs}
67+
68+
69+
async def _celery_task_status(
70+
job_creation_task_id: TaskID | None,
71+
task_manager: TaskManager,
72+
user_id: UserID,
73+
product_name: ProductName,
74+
) -> str:
75+
if job_creation_task_id is None:
76+
return _JOB_CREATION_TASK_NOT_YET_SCHEDULED_STATUS
77+
task_filter = _get_task_filter(user_id, product_name)
78+
task_status = await task_manager.get_task_status(
79+
task_uuid=TaskUUID(job_creation_task_id), task_filter=task_filter
80+
)
81+
return f"{_JOB_CREATION_TASK_STATUS_PREFIX}{task_status.task_state}"
82+
83+
84+
@dataclass(frozen=True, kw_only=True)
85+
class FunctionJobTaskClientService:
86+
user_id: UserID
87+
product_name: ProductName
88+
_web_rpc_client: WbApiRpcClient
89+
_storage_client: StorageService
90+
_job_service: JobService
91+
_function_service: FunctionService
92+
_function_job_service: FunctionJobService
93+
_webserver_api: AuthSession
94+
_celery_task_manager: TaskManager
95+
_async_pg_engine: AsyncEngine
96+
97+
async def list_function_jobs_with_status(
98+
self,
99+
*,
100+
function: RegisteredFunction,
101+
filter_by_function_id: FunctionID | None = None,
102+
filter_by_function_job_ids: list[FunctionJobID] | None = None,
103+
filter_by_function_job_collection_id: FunctionJobCollectionID | None = None,
104+
pagination_offset: PageOffsetInt | None = None,
105+
pagination_limit: PageLimitInt | None = None,
106+
) -> tuple[
107+
list[RegisteredFunctionJobWithStatus],
108+
PageMetaInfoLimitOffset,
109+
]:
110+
"""Lists all function jobs for a user with pagination"""
111+
112+
pagination_kwargs = as_dict_exclude_none(
113+
pagination_offset=pagination_offset, pagination_limit=pagination_limit
114+
)
115+
116+
function_jobs_list_ws, meta = (
117+
await self._web_rpc_client.list_function_jobs_with_status(
118+
user_id=self.user_id,
119+
product_name=self.product_name,
120+
filter_by_function_id=filter_by_function_id,
121+
filter_by_function_job_ids=filter_by_function_job_ids,
122+
filter_by_function_job_collection_id=filter_by_function_job_collection_id,
123+
**pagination_kwargs,
124+
)
125+
)
126+
127+
for function_job_wso in function_jobs_list_ws:
128+
if (
129+
function_job_wso.status.status
130+
not in (
131+
RunningState.SUCCESS,
132+
RunningState.FAILED,
133+
)
134+
) or function_job_wso.outputs is None:
135+
function_job_wso.status = await self.inspect_function_job(
136+
function=await self._function_service.get_function(
137+
function_id=function_job_wso.function_uid,
138+
),
139+
function_job=function_job_wso,
140+
)
141+
142+
if function_job_wso.status.status == RunningState.SUCCESS:
143+
function_job_wso.outputs = await self.function_job_outputs(
144+
function_job=function_job_wso,
145+
function=function,
146+
stored_job_outputs=None,
147+
)
148+
return function_jobs_list_ws, meta
149+
150+
async def inspect_function_job(
151+
self,
152+
function: RegisteredFunction,
153+
function_job: RegisteredFunctionJob,
154+
) -> FunctionJobStatus:
155+
"""Raises FunctionJobProjectNotRegisteredError if no project is associated with job"""
156+
stored_job_status = await self._web_rpc_client.get_function_job_status(
157+
function_job_id=function_job.uid,
158+
user_id=self.user_id,
159+
product_name=self.product_name,
160+
)
161+
162+
if stored_job_status.status in (RunningState.SUCCESS, RunningState.FAILED):
163+
return stored_job_status
164+
165+
status: str
166+
if (
167+
function.function_class == FunctionClass.PROJECT
168+
and function_job.function_class == FunctionClass.PROJECT
169+
):
170+
if function_job.project_job_id is None:
171+
status = await _celery_task_status(
172+
job_creation_task_id=function_job.job_creation_task_id,
173+
task_manager=self._celery_task_manager,
174+
user_id=self.user_id,
175+
product_name=self.product_name,
176+
)
177+
else:
178+
job_status = await self._job_service.inspect_study_job(
179+
job_id=function_job.project_job_id,
180+
)
181+
status = job_status.state
182+
elif (function.function_class == FunctionClass.SOLVER) and (
183+
function_job.function_class == FunctionClass.SOLVER
184+
):
185+
if function_job.solver_job_id is None:
186+
status = await _celery_task_status(
187+
job_creation_task_id=function_job.job_creation_task_id,
188+
task_manager=self._celery_task_manager,
189+
user_id=self.user_id,
190+
product_name=self.product_name,
191+
)
192+
else:
193+
job_status = await self._job_service.inspect_solver_job(
194+
solver_key=function.solver_key,
195+
version=function.solver_version,
196+
job_id=function_job.solver_job_id,
197+
)
198+
status = job_status.state
199+
else:
200+
raise UnsupportedFunctionFunctionJobClassCombinationError(
201+
function_class=function.function_class,
202+
function_job_class=function_job.function_class,
203+
)
204+
205+
new_job_status = FunctionJobStatus(status=status)
206+
207+
return await self._web_rpc_client.update_function_job_status(
208+
function_job_id=function_job.uid,
209+
user_id=self.user_id,
210+
product_name=self.product_name,
211+
job_status=new_job_status,
212+
check_write_permissions=False,
213+
)
214+
215+
async def get_cached_function_job(
216+
self,
217+
*,
218+
function: RegisteredFunction,
219+
job_inputs: JobInputs,
220+
) -> RegisteredFunctionJob:
221+
"""
222+
N.B. this function checks access rights
223+
224+
raises FunctionsExecuteApiAccessDeniedError if user cannot execute functions
225+
raises FunctionJobCacheNotFoundError if no cached job is found
226+
227+
"""
228+
229+
user_api_access_rights = (
230+
await self._web_rpc_client.get_functions_user_api_access_rights(
231+
user_id=self.user_id, product_name=self.product_name
232+
)
233+
)
234+
if not user_api_access_rights.execute_functions:
235+
raise FunctionsExecuteApiAccessDeniedError(
236+
user_id=self.user_id,
237+
function_id=function.uid,
238+
)
239+
240+
user_permissions = await self._web_rpc_client.get_function_user_permissions(
241+
function_id=function.uid,
242+
user_id=self.user_id,
243+
product_name=self.product_name,
244+
)
245+
if not user_permissions.execute:
246+
raise FunctionExecuteAccessDeniedError(
247+
user_id=self.user_id,
248+
function_id=function.uid,
249+
)
250+
251+
if cached_function_jobs := await self._web_rpc_client.find_cached_function_jobs(
252+
function_id=function.uid,
253+
inputs=job_inputs.values,
254+
user_id=self.user_id,
255+
product_name=self.product_name,
256+
):
257+
for cached_function_job in cached_function_jobs:
258+
job_status = await self.inspect_function_job(
259+
function=function,
260+
function_job=cached_function_job,
261+
)
262+
if job_status.status == RunningState.SUCCESS:
263+
return cached_function_job
264+
265+
raise FunctionJobCacheNotFoundError
266+
267+
async def pre_register_function_job(
268+
self,
269+
*,
270+
function: RegisteredFunction,
271+
job_inputs: JobInputs,
272+
) -> PreRegisteredFunctionJobData:
273+
274+
if function.input_schema is not None:
275+
is_valid, validation_str = (
276+
await self._function_job_service.validate_function_inputs(
277+
function_id=function.uid,
278+
inputs=job_inputs.values,
279+
)
280+
)
281+
if not is_valid:
282+
raise FunctionInputsValidationError(error=validation_str)
283+
284+
if function.function_class == FunctionClass.PROJECT:
285+
job = await self._web_rpc_client.register_function_job(
286+
function_job=ProjectFunctionJob(
287+
function_uid=function.uid,
288+
title=f"Function job of function {function.uid}",
289+
description=function.description,
290+
inputs=job_inputs.values,
291+
outputs=None,
292+
project_job_id=None,
293+
job_creation_task_id=None,
294+
),
295+
user_id=self.user_id,
296+
product_name=self.product_name,
297+
)
298+
299+
elif function.function_class == FunctionClass.SOLVER:
300+
job = await self._web_rpc_client.register_function_job(
301+
function_job=SolverFunctionJob(
302+
function_uid=function.uid,
303+
title=f"Function job of function {function.uid}",
304+
description=function.description,
305+
inputs=job_inputs.values,
306+
outputs=None,
307+
solver_job_id=None,
308+
job_creation_task_id=None,
309+
),
310+
user_id=self.user_id,
311+
product_name=self.product_name,
312+
)
313+
else:
314+
raise UnsupportedFunctionClassError(
315+
function_class=function.function_class,
316+
)
317+
318+
return PreRegisteredFunctionJobData(
319+
function_job_id=job.uid,
320+
job_inputs=job_inputs,
321+
)
322+
323+
async def function_job_outputs(
324+
self,
325+
*,
326+
function: RegisteredFunction,
327+
function_job: RegisteredFunctionJob,
328+
stored_job_outputs: FunctionOutputs | None,
329+
) -> FunctionOutputs:
330+
331+
if stored_job_outputs is not None:
332+
return stored_job_outputs
333+
334+
job_status = await self.inspect_function_job(
335+
function=function, function_job=function_job
336+
)
337+
338+
if job_status.status != RunningState.SUCCESS:
339+
return None
340+
341+
if (
342+
function.function_class == FunctionClass.PROJECT
343+
and function_job.function_class == FunctionClass.PROJECT
344+
):
345+
if function_job.project_job_id is None:
346+
return None
347+
new_outputs = dict(
348+
(
349+
await self._job_service.get_study_job_outputs(
350+
study_id=function.project_id,
351+
job_id=function_job.project_job_id,
352+
)
353+
).results
354+
)
355+
elif (
356+
function.function_class == FunctionClass.SOLVER
357+
and function_job.function_class == FunctionClass.SOLVER
358+
):
359+
if function_job.solver_job_id is None:
360+
return None
361+
new_outputs = dict(
362+
(
363+
await self._job_service.get_solver_job_outputs(
364+
solver_key=function.solver_key,
365+
version=function.solver_version,
366+
job_id=function_job.solver_job_id,
367+
async_pg_engine=self._async_pg_engine,
368+
)
369+
).results
370+
)
371+
else:
372+
raise UnsupportedFunctionClassError(function_class=function.function_class)
373+
374+
return await self._web_rpc_client.update_function_job_outputs(
375+
function_job_id=function_job.uid,
376+
user_id=self.user_id,
377+
product_name=self.product_name,
378+
outputs=new_outputs,
379+
check_write_permissions=False,
380+
)

0 commit comments

Comments
 (0)