Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from ._core import (
cancel_operation,
get_operation_name,
restart_operation_step_stuck_during_revert,
restart_operation_step_stuck_in_manual_intervention_during_execute,
start_operation,
Expand Down Expand Up @@ -35,6 +36,7 @@
"cancel_operation",
"generic_scheduler_lifespan",
"get_operation_context_proxy",
"get_operation_name",
"get_step_group_proxy",
"get_step_store_proxy",
"Operation",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
OperationErrorType,
OperationName,
OperationToStart,
ReservedContextKeys,
ScheduleId,
StepName,
StepStatus,
Expand Down Expand Up @@ -100,6 +101,8 @@ async def start_operation(
"""start an operation by it's given name and providing an initial context"""
schedule_id: ScheduleId = f"{uuid4()}"

initial_operation_context[ReservedContextKeys.SCHEDULE_ID] = schedule_id

# check if operation is registered
operation = OperationRegistry.get_operation(operation_name)

Expand Down Expand Up @@ -213,6 +216,15 @@ async def _cancel_step(step_name: StepName, step_proxy: StepStoreProxy) -> None:
limit=PARALLEL_REQUESTS,
)

async def get_operation_name(self, schedule_id: ScheduleId) -> OperationName | None:
schedule_data_proxy = ScheduleDataStoreProxy(
store=self._store, schedule_id=schedule_id
)
try:
return await schedule_data_proxy.read("operation_name")
except NoDataFoundError:
return None

async def restart_operation_step_stuck_in_error(
self,
schedule_id: ScheduleId,
Expand Down Expand Up @@ -732,6 +744,13 @@ async def cancel_operation(app: FastAPI, schedule_id: ScheduleId) -> None:
await Core.get_from_app_state(app).cancel_operation(schedule_id)


async def get_operation_name(
app: FastAPI, schedule_id: ScheduleId
) -> OperationName | None:
"""returns the name of the operation or None if not found"""
return await Core.get_from_app_state(app).get_operation_name(schedule_id)


async def restart_operation_step_stuck_in_manual_intervention_during_execute(
app: FastAPI, schedule_id: ScheduleId, step_name: StepName
) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,23 @@ async def register_to_start_after(
schedule_id: ScheduleId,
event_type: EventType,
*,
to_start: OperationToStart,
to_start: OperationToStart | None,
) -> None:

events_proxy = OperationEventsProxy(self._store, schedule_id, event_type)
if to_start is None:
# unregister any previously registered operation
await events_proxy.delete()
_logger.debug(
"Unregistered event_type='%s' to_start for schedule_id='%s'",
event_type,
schedule_id,
)
return

# ensure operation exists
OperationRegistry.get_operation(to_start.operation_name)

events_proxy = OperationEventsProxy(self._store, schedule_id, event_type)
await events_proxy.create_or_update_multiple(
{
"initial_context": to_start.initial_context,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,15 @@ def _get_after_event_manager(app: FastAPI) -> "AfterEventManager":


async def register_to_start_after_on_executed_completed(
app: FastAPI, schedule_id: ScheduleId, *, to_start: OperationToStart
app: FastAPI, schedule_id: ScheduleId, *, to_start: OperationToStart | None
) -> None:
await _get_after_event_manager(app).register_to_start_after(
schedule_id, EventType.ON_EXECUTEDD_COMPLETED, to_start=to_start
)


async def register_to_start_after_on_reverted_completed(
app: FastAPI, schedule_id: ScheduleId, *, to_start: OperationToStart
app: FastAPI, schedule_id: ScheduleId, *, to_start: OperationToStart | None
) -> None:
await _get_after_event_manager(app).register_to_start_after(
schedule_id, EventType.ON_REVERT_COMPLETED, to_start=to_start
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from enum import auto
from enum import Enum, auto
from typing import Annotated, Any, Final, TypeAlias

from models_library.basic_types import UUIDStr
Expand Down Expand Up @@ -51,3 +51,10 @@ class EventType(StrAutoEnum):
class OperationToStart:
operation_name: OperationName
initial_context: OperationContext


class ReservedContextKeys(str, Enum):
SCHEDULE_ID = "_schedule_id"


ALL_RESERVED_CONTEXT_KEYS: Final[set[str]] = {x.value for x in ReservedContextKeys}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
StepNotFoundInoperationError,
)
from ._models import (
ALL_RESERVED_CONTEXT_KEYS,
OperationName,
ProvidedOperationContext,
RequiredOperationContext,
Expand Down Expand Up @@ -245,7 +246,7 @@ def _has_abstract_methods(cls: type[object]) -> bool:


@validate_call(config={"arbitrary_types_allowed": True})
def _validate_operation( # noqa: C901
def _validate_operation( # noqa: C901, PLR0912 # pylint: disable=too-many-branches
operation: Operation,
) -> dict[StepName, type[BaseStep]]:
if len(operation.step_groups) == 0:
Expand Down Expand Up @@ -285,14 +286,27 @@ def _validate_operation( # noqa: C901
detected_steps_names[step_name] = step

for key in step.get_execute_provides_context_keys():
if key in ALL_RESERVED_CONTEXT_KEYS:
msg = (
f"Step {step_name=} provides {key=} which is part of reserved keys "
f"{ALL_RESERVED_CONTEXT_KEYS=}"
)
raise ValueError(msg)
if key in execute_provided_keys:
msg = (
f"Step {step_name=} provides already provided {key=} in "
f"{step.get_execute_provides_context_keys.__name__}()"
)
raise ValueError(msg)
execute_provided_keys.add(key)

for key in step.get_revert_provides_context_keys():
if key in ALL_RESERVED_CONTEXT_KEYS:
msg = (
f"Step {step_name=} provides {key=} which is part of reserved keys "
f"{ALL_RESERVED_CONTEXT_KEYS=}"
)
raise ValueError(msg)
if key in revert_provided_keys:
msg = (
f"Step {step_name=} provides already provided {key=} in "
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
SingleStepGroup,
StepStoreProxy,
cancel_operation,
get_operation_name,
restart_operation_step_stuck_during_revert,
restart_operation_step_stuck_in_manual_intervention_during_execute,
start_operation,
Expand Down Expand Up @@ -1753,3 +1754,20 @@ async def test_step_does_not_provide_declared_key_or_is_none(

formatted_expected_keys = {k.format(schedule_id=schedule_id) for k in expected_keys}
await ensure_keys_in_store(selected_app, expected_keys=formatted_expected_keys)


@pytest.mark.parametrize("app_count", [10])
async def test_get_operation_name(
preserve_caplog_for_async_logging: None,
operation_name: OperationName,
selected_app: FastAPI,
register_operation: Callable[[OperationName, Operation], None],
):
assert await get_operation_name(selected_app, "non_existing_schedule_id") is None

operation = Operation(SingleStepGroup(_S1))
register_operation(operation_name, operation)

schedule_id = await start_operation(selected_app, operation_name, {})

assert await get_operation_name(selected_app, schedule_id) == operation_name
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
StepNotFoundInoperationError,
)
from simcore_service_dynamic_scheduler.services.generic_scheduler._models import (
ALL_RESERVED_CONTEXT_KEYS,
ProvidedOperationContext,
RequiredOperationContext,
ReservedContextKeys,
)
from simcore_service_dynamic_scheduler.services.generic_scheduler._operation import (
BaseStep,
Expand Down Expand Up @@ -57,6 +59,12 @@ def get_execute_provides_context_keys(cls) -> set[str]:
return {"execute_key"}


class WrongBS3C(BaseBS):
@classmethod
def get_execute_provides_context_keys(cls) -> set[str]:
return {ReservedContextKeys.SCHEDULE_ID}


class WrongBS1R(BaseBS):
@classmethod
def get_revert_provides_context_keys(cls) -> set[str]:
Expand All @@ -69,6 +77,22 @@ def get_revert_provides_context_keys(cls) -> set[str]:
return {"revert_key"}


class WrongBS3R(BaseBS):
@classmethod
def get_revert_provides_context_keys(cls) -> set[str]:
return {ReservedContextKeys.SCHEDULE_ID}


class AllowedKeysBS(BaseBS):
@classmethod
def get_execute_requires_context_keys(cls) -> set[str]:
return {ReservedContextKeys.SCHEDULE_ID}

@classmethod
def get_revert_requires_context_keys(cls) -> set[str]:
return {ReservedContextKeys.SCHEDULE_ID}


@pytest.mark.parametrize(
"operation",
[
Expand Down Expand Up @@ -104,6 +128,9 @@ def get_revert_provides_context_keys(cls) -> set[str]:
Operation(
ParallelStepGroup(BS1, BS3, repeat_steps=True),
),
Operation(
SingleStepGroup(AllowedKeysBS),
),
],
)
def test_validate_operation_passes(operation: Operation):
Expand Down Expand Up @@ -166,6 +193,14 @@ def test_validate_operation_passes(operation: Operation):
),
"cannot have steps that require manual intervention",
),
(
Operation(SingleStepGroup(WrongBS3C)),
"which is part of reserved keys ALL_RESERVED_CONTEXT_KEYS",
),
(
Operation(SingleStepGroup(WrongBS3R)),
"which is part of reserved keys ALL_RESERVED_CONTEXT_KEYS",
),
],
)
def test_validate_operations_fails(operation: Operation, match: str):
Expand Down Expand Up @@ -204,3 +239,9 @@ def test_operation_registry_raises_errors():

with pytest.raises(StepNotFoundInoperationError):
OperationRegistry.get_step("op1", "non_existing")


def test_reserved_context_keys_existence():
for e in ReservedContextKeys:
assert e.value in ALL_RESERVED_CONTEXT_KEYS
assert "missing_key" not in ALL_RESERVED_CONTEXT_KEYS
Original file line number Diff line number Diff line change
Expand Up @@ -387,9 +387,13 @@ async def test_can_recover_from_interruption(
await ensure_expected_order(queue_poller.events, expected_order)


_INITIAL_OP_NAME: OperationName = "initial"
_AFTER_OP_NAME: OperationName = "after"


@pytest.mark.parametrize("register_at_creation", [True, False])
@pytest.mark.parametrize(
"is_executing, initial_op, after_op, expected_order",
"is_executing, initial_op, after_op, expected_order, to_start",
[
pytest.param(
True,
Expand All @@ -399,6 +403,16 @@ async def test_can_recover_from_interruption(
ExecuteSequence(_ShortSleep),
ExecuteSequence(_S2),
],
OperationToStart(operation_name=_AFTER_OP_NAME, initial_context={}),
),
pytest.param(
True,
Operation(SingleStepGroup(_ShortSleep)),
None,
[
ExecuteSequence(_ShortSleep),
],
None,
),
pytest.param(
False,
Expand All @@ -409,6 +423,17 @@ async def test_can_recover_from_interruption(
RevertSequence(_ShortSleepThenRevert),
ExecuteSequence(_S2),
],
OperationToStart(operation_name=_AFTER_OP_NAME, initial_context={}),
),
pytest.param(
False,
Operation(SingleStepGroup(_ShortSleepThenRevert)),
None,
[
ExecuteSequence(_ShortSleepThenRevert),
RevertSequence(_ShortSleepThenRevert),
],
None,
),
],
)
Expand All @@ -420,33 +445,25 @@ async def test_run_operation_after(
register_at_creation: bool,
is_executing: bool,
initial_op: Operation,
after_op: Operation,
after_op: Operation | None,
expected_order: list[BaseExpectedStepOrder],
to_start: OperationToStart | None,
):
initial_op_name: OperationName = "initial"
after_op_name: OperationName = "after"

register_operation(initial_op_name, initial_op)
register_operation(after_op_name, after_op)
register_operation(_INITIAL_OP_NAME, initial_op)
if after_op is not None:
register_operation(_AFTER_OP_NAME, after_op)

if is_executing:
on_execute_completed = (
OperationToStart(operation_name=after_op_name, initial_context={})
if register_at_creation
else None
)
on_execute_completed = to_start if register_at_creation else None
on_revert_completed = None
else:
on_execute_completed = None
on_revert_completed = (
OperationToStart(operation_name=after_op_name, initial_context={})
if register_at_creation
else None
)
on_revert_completed = to_start if register_at_creation else None

schedule_id = await start_operation(
app,
initial_op_name,
_INITIAL_OP_NAME,
{},
on_execute_completed=on_execute_completed,
on_revert_completed=on_revert_completed,
Expand All @@ -455,19 +472,11 @@ async def test_run_operation_after(
if register_at_creation is False:
if is_executing:
await register_to_start_after_on_executed_completed(
app,
schedule_id,
to_start=OperationToStart(
operation_name=after_op_name, initial_context={}
),
app, schedule_id, to_start=to_start
)
else:
await register_to_start_after_on_reverted_completed(
app,
schedule_id,
to_start=OperationToStart(
operation_name=after_op_name, initial_context={}
),
app, schedule_id, to_start=to_start
)

await ensure_expected_order(steps_call_order, expected_order)
Expand Down
Loading