Skip to content

Commit 2c57d5b

Browse files
author
Andrei Neagu
committed
refacrtored core modules
1 parent d8b520b commit 2c57d5b

File tree

5 files changed

+332
-9
lines changed

5 files changed

+332
-9
lines changed
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from common_library.errors_classes import OsparcErrorMixin
2+
3+
4+
class BaseGenericSchedulerError(OsparcErrorMixin, Exception):
5+
"""base exception for this module"""
6+
7+
8+
class KeyNotFoundInHashError(BaseGenericSchedulerError):
9+
msg_template: str = "Key '{key}' not found in hash '{hash_key}'"
10+
11+
12+
class OperationAlreadyRegisteredError(BaseGenericSchedulerError):
13+
msg_template: str = "Operation '{operation_name}' already registered"
14+
15+
16+
class OperationNotFoundError(BaseGenericSchedulerError):
17+
msg_template: str = (
18+
"Operation '{operation_name}' was not found, registerd_operations='{registerd_operations}'"
19+
)
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
from abc import ABC, abstractmethod
2+
from typing import Annotated, ClassVar, Final, TypeAlias
3+
4+
from fastapi import FastAPI
5+
from pydantic import Field, NonNegativeInt, TypeAdapter, validate_call
6+
7+
from ._errors import OperationAlreadyRegisteredError, OperationNotFoundError
8+
from ._models import OperationName, StepGroupName, StepName
9+
10+
11+
class BaseStep(ABC):
12+
@classmethod
13+
def get_step_name(cls) -> StepName:
14+
return cls.__name__
15+
16+
@classmethod
17+
@abstractmethod
18+
async def create(cls, app: FastAPI) -> None:
19+
"""
20+
[mandatory] handler to be implemented with the code resposible for achieving a goal
21+
"""
22+
23+
@classmethod
24+
async def destroy(cls, app: FastAPI) -> None:
25+
"""
26+
[optional] handler resposible for celanup of resources created above.
27+
NOTE: Ensure this is successful if:
28+
- `create` is not executed
29+
- `create` is executed partially
30+
- `destory` is called multiple times
31+
"""
32+
_ = app
33+
34+
35+
StepsSubGroup: TypeAlias = Annotated[tuple[type[BaseStep], ...], Field(min_length=1)]
36+
37+
38+
class BaseStepGroup(ABC):
39+
def __init__(self, *, repeat_steps: bool) -> None:
40+
"""
41+
if repeat_steps is True, the steps in this group will be repeated forever
42+
"""
43+
self.repeat_steps = repeat_steps
44+
45+
@abstractmethod
46+
def get_step_group_name(self, *, index: NonNegativeInt) -> StepGroupName:
47+
"""returns the name of this step group"""
48+
49+
@abstractmethod
50+
def get_steps_names(self) -> list[StepName]:
51+
"""return sorted list of StepName"""
52+
53+
@abstractmethod
54+
def get_step_subgroup_to_run(self) -> StepsSubGroup:
55+
"""returns subgroups of steps to run"""
56+
57+
58+
class SingleStepGroup(BaseStepGroup):
59+
def __init__(self, step: type[BaseStep], *, repeat_steps: bool = False) -> None:
60+
self._step: type[BaseStep] = step
61+
super().__init__(repeat_steps=repeat_steps)
62+
63+
def get_step_group_name(self, *, index: NonNegativeInt) -> StepGroupName:
64+
return f"{index}S{'R' if self.repeat_steps else ''}"
65+
66+
def get_steps_names(self) -> list[StepName]:
67+
return [self._step.get_step_name()]
68+
69+
def get_step_subgroup_to_run(self) -> StepsSubGroup:
70+
return TypeAdapter(StepsSubGroup).validate_python((self._step,))
71+
72+
73+
_MIN_PARALLEL_STEPS: Final[int] = 2
74+
75+
76+
class ParallelStepGroup(BaseStepGroup):
77+
def __init__(self, *steps: type[BaseStep], repeat_steps: bool = False) -> None:
78+
79+
self._steps: list[type[BaseStep]] = list(steps)
80+
81+
super().__init__(repeat_steps=repeat_steps)
82+
83+
@property
84+
def steps(self) -> list[type[BaseStep]]:
85+
return self._steps
86+
87+
def get_step_group_name(self, *, index: NonNegativeInt) -> StepGroupName:
88+
return f"{index}P{'R' if self.repeat_steps else ''}"
89+
90+
def get_steps_names(self) -> list[StepName]:
91+
return sorted(x.get_step_name() for x in self._steps)
92+
93+
def get_step_subgroup_to_run(self) -> StepsSubGroup:
94+
return TypeAdapter(StepsSubGroup).validate_python(tuple(self._steps))
95+
96+
97+
Operation: TypeAlias = Annotated[list[BaseStepGroup], Field(min_length=1)]
98+
99+
100+
@validate_call(config={"arbitrary_types_allowed": True})
101+
def _validate_operation(operation: Operation) -> None:
102+
detected_steps_names: set[StepName] = set()
103+
104+
for k, step_group in enumerate(operation):
105+
if isinstance(step_group, ParallelStepGroup):
106+
if len(step_group.steps) < _MIN_PARALLEL_STEPS:
107+
msg = (
108+
f"{ParallelStepGroup.__name__} needs at least {_MIN_PARALLEL_STEPS} "
109+
f"steps. TIP: use {SingleStepGroup.__name__} instead."
110+
)
111+
raise ValueError(msg)
112+
113+
if k < len(operation) - 1 and step_group.repeat_steps is True:
114+
msg = f"Only the last step group can have repeat_steps=True. Error at index {k=}"
115+
raise ValueError(msg)
116+
117+
for step in step_group.get_step_subgroup_to_run():
118+
step_name = step.get_step_name()
119+
120+
if step_name in detected_steps_names:
121+
msg = f"Step {step_name=} is already used in this operation {detected_steps_names=}"
122+
raise ValueError(msg)
123+
124+
detected_steps_names.add(step_name)
125+
126+
127+
class OperationRegistry:
128+
_OPERATIONS: ClassVar[dict[str, Operation]] = {}
129+
130+
@classmethod
131+
def register(cls, operation_name: OperationName, operation: Operation) -> None:
132+
_validate_operation(operation)
133+
134+
if operation_name in cls._OPERATIONS:
135+
raise OperationAlreadyRegisteredError(operation_name=operation_name)
136+
137+
cls._OPERATIONS[operation_name] = operation
138+
139+
@classmethod
140+
def get(cls, operation_name: OperationName) -> Operation:
141+
if operation_name not in cls._OPERATIONS:
142+
raise OperationNotFoundError(
143+
operation_name=operation_name,
144+
registerd_operations=list(cls._OPERATIONS.keys()),
145+
)
146+
147+
return cls._OPERATIONS[operation_name]
148+
149+
@classmethod
150+
def unregister(cls, operation_name: OperationName) -> None:
151+
if operation_name not in cls._OPERATIONS:
152+
raise OperationNotFoundError(
153+
operation_name=operation_name,
154+
registerd_operations=list(cls._OPERATIONS.keys()),
155+
)
156+
157+
del cls._OPERATIONS[operation_name]

services/dynamic-scheduler/src/simcore_service_dynamic_scheduler/services/generic_scheduler/_store.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,14 @@
88
from servicelib.redis._utils import handle_redis_returns_union_types
99
from settings_library.redis import RedisDatabase, RedisSettings
1010

11-
from ._models import OperationName, ScheduleId, StepGroup, StepName
11+
from ._errors import KeyNotFoundInHashError
12+
from ._models import (
13+
OperationContext,
14+
OperationName,
15+
ScheduleId,
16+
StepGroupName,
17+
StepName,
18+
)
1219

1320
_SCHEDULE_NAMESPACE: Final[str] = "SCH"
1421
_STEPS_KEY: Final[str] = "STEPS"
@@ -26,7 +33,7 @@ def _get_step_hash_key(
2633
*,
2734
schedule_id: ScheduleId,
2835
operation_name: OperationName,
29-
group: StepGroup,
36+
group: StepGroupName,
3037
step_name: StepName,
3138
) -> str:
3239
# SCHEDULE_NAMESPACE:SCHEDULE_ID:STEPS:OPERATION_NAME:GROUP_INDEX:STEP_NAME:KEY
@@ -102,11 +109,14 @@ async def remove(self, hash_key: str) -> None:
102109

103110
class _UpdateScheduleDataDict(TypedDict):
104111
operation_name: NotRequired[OperationName]
112+
operation_context: NotRequired[OperationContext]
105113
group_index: NotRequired[NonNegativeInt]
106114
is_creating: NotRequired[bool]
107115

108116

109-
_DeleteScheduleDataKeys = Literal["operation_name", "group_index", "is_creating"]
117+
_DeleteScheduleDataKeys = Literal[
118+
"operation_name", "operation_context", "group_index", "is_creating"
119+
]
110120

111121

112122
class ScheduleDataStoreProxy:
@@ -118,19 +128,33 @@ def _get_hash_key(self) -> str:
118128
return _get_scheduler_data_hash_key(schedule_id=self._schedule_id)
119129

120130
@overload
121-
async def get(self, key: Literal["operation_name"]) -> str: ...
131+
async def get(self, key: Literal["operation_name"]) -> OperationName: ...
122132
@overload
123-
async def get(self, key: Literal["group_index"]) -> int: ...
133+
async def get(self, key: Literal["operation_context"]) -> OperationContext: ...
134+
@overload
135+
async def get(self, key: Literal["group_index"]) -> NonNegativeInt: ...
124136
@overload
125137
async def get(self, key: Literal["is_creating"]) -> bool: ...
126138
async def get(self, key: str) -> Any:
127-
(result,) = await self._store.get(self._get_hash_key(), key)
139+
"""raises KeyNotFoundInHashError if the key is not present in the hash"""
140+
hash_key = self._get_hash_key()
141+
(result,) = await self._store.get(hash_key, key)
142+
if result is None:
143+
raise KeyNotFoundInHashError(
144+
schedule_id=self._schedule_id, hash_key=hash_key
145+
)
128146
return result
129147

130148
@overload
131-
async def set(self, key: Literal["operation_name"], value: str) -> None: ...
149+
async def set(
150+
self, key: Literal["operation_name"], value: OperationName
151+
) -> None: ...
152+
@overload
153+
async def set(
154+
self, key: Literal["operation_context"], value: OperationContext
155+
) -> None: ...
132156
@overload
133-
async def set(self, key: Literal["group_index"], value: int) -> None: ...
157+
async def set(self, key: Literal["group_index"], value: NonNegativeInt) -> None: ...
134158
@overload
135159
async def set(self, key: Literal["is_creating"], *, value: bool) -> None: ...
136160
async def set(self, key: str, value: Any) -> None:
@@ -150,7 +174,7 @@ def __init__(
150174
store: Store,
151175
schedule_id: ScheduleId,
152176
operation_name: OperationName,
153-
group: StepGroup,
177+
group: StepGroupName,
154178
step_name: StepName,
155179
) -> None:
156180
self._store = store
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
# pylint: disable=protected-access
2+
3+
import pytest
4+
from simcore_service_dynamic_scheduler.services.generic_scheduler._errors import (
5+
OperationAlreadyRegisteredError,
6+
OperationNotFoundError,
7+
)
8+
from simcore_service_dynamic_scheduler.services.generic_scheduler._operation import (
9+
BaseStep,
10+
Operation,
11+
OperationRegistry,
12+
ParallelStepGroup,
13+
SingleStepGroup,
14+
_validate_operation,
15+
)
16+
17+
18+
class BS1(BaseStep):
19+
pass
20+
21+
22+
class BS2(BaseStep):
23+
pass
24+
25+
26+
class BS3(BaseStep):
27+
pass
28+
29+
30+
@pytest.mark.parametrize(
31+
"operation",
32+
[
33+
[
34+
SingleStepGroup(BS1),
35+
ParallelStepGroup(BS2, BS3),
36+
],
37+
[
38+
SingleStepGroup(BS1),
39+
],
40+
[
41+
SingleStepGroup(BS1),
42+
SingleStepGroup(BS2),
43+
],
44+
[
45+
SingleStepGroup(BS2),
46+
ParallelStepGroup(BS1, BS3, repeat_steps=True),
47+
],
48+
[
49+
ParallelStepGroup(BS1, BS3),
50+
SingleStepGroup(BS2, repeat_steps=True),
51+
],
52+
[
53+
SingleStepGroup(BS1, repeat_steps=True),
54+
],
55+
[
56+
ParallelStepGroup(BS1, BS3, repeat_steps=True),
57+
],
58+
],
59+
)
60+
def test_validate_operation_passes(operation: Operation):
61+
_validate_operation(operation)
62+
63+
64+
@pytest.mark.parametrize(
65+
"operation, match",
66+
[
67+
([], "List should have at least 1 item after validation"),
68+
(
69+
[
70+
SingleStepGroup(BS1, repeat_steps=True),
71+
SingleStepGroup(BS2),
72+
],
73+
"Only the last step group can have repeat_steps=True",
74+
),
75+
(
76+
[
77+
SingleStepGroup(BS1),
78+
SingleStepGroup(BS1),
79+
],
80+
f"step_name='{BS1.__name__}' is already used in this operation",
81+
),
82+
(
83+
[
84+
ParallelStepGroup(BS2, BS2),
85+
],
86+
f"step_name='{BS2.__name__}' is already used in this operation",
87+
),
88+
(
89+
[
90+
ParallelStepGroup(BS1),
91+
],
92+
f"{ParallelStepGroup.__name__} needs at least 2 steps",
93+
),
94+
],
95+
)
96+
def test_validate_operations_fails(operation: Operation, match: str):
97+
with pytest.raises(ValueError, match=match):
98+
_validate_operation(operation)
99+
100+
101+
def test_operation_registry_workflow():
102+
operation: Operation = [SingleStepGroup(BS1)]
103+
OperationRegistry.register("op1", operation)
104+
assert len(OperationRegistry._OPERATIONS) == 1
105+
106+
assert OperationRegistry.get("op1") == operation
107+
108+
OperationRegistry.unregister("op1")
109+
assert len(OperationRegistry._OPERATIONS) == 0
110+
111+
112+
def test_operation_registry_register_twice_fails():
113+
operation: Operation = [SingleStepGroup(BS1)]
114+
OperationRegistry.register("op1", operation)
115+
116+
with pytest.raises(OperationAlreadyRegisteredError):
117+
OperationRegistry.register("op1", operation)
118+
119+
with pytest.raises(OperationNotFoundError):
120+
OperationRegistry.get("non_existing")
121+
122+
with pytest.raises(OperationNotFoundError):
123+
OperationRegistry.unregister("non_existing")

0 commit comments

Comments
 (0)