Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -585,20 +585,19 @@ async def _fs_handle_manually_cancelled( # pylint:disable=method-hidden
self, task_uid: TaskUID
) -> None:
_log_state(TaskState.MANUALLY_CANCELLED, task_uid)
_logger.info("Attempting to cancel task_uid '%s'", task_uid)
_logger.info("Recevied a cancel request for task_uid '%s'", task_uid)

task_schedule = await self.__get_task_schedule(
task_uid, expected_state=TaskState.MANUALLY_CANCELLED
)

if task_schedule.state == TaskState.WORKER:
run_was_cancelled = self._worker_tracker.cancel_run(task_uid)
if not run_was_cancelled:
_logger.debug(
"Currently not handling task related to '%s'. Did not cancel it.",
task_uid,
)
return
run_was_cancelled = self._worker_tracker.cancel_run(task_uid)
if not run_was_cancelled:
_logger.debug(
"Currently not handling task related to '%s'. Did not cancel it.",
task_uid,
)
return

_logger.info("Found and cancelled run for '%s'", task_uid)
await self.__remove_task(task_uid, task_schedule)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,14 +180,14 @@ async def cancel_operation(self, schedule_id: ScheduleId) -> None:
if operation.is_cancellable is False:
raise OperationNotCancellableError(operation_name=operation_name)

group = operation.step_groups[group_index]
step_group = operation.step_groups[group_index]

group_step_proxies = get_group_step_proxies(
self._store,
schedule_id=schedule_id,
operation_name=operation_name,
group_index=group_index,
step_group=group,
step_group=step_group,
is_executing=is_executing,
)

Expand All @@ -204,6 +204,8 @@ async def cancel_operation(self, schedule_id: ScheduleId) -> None:
schedule_id=schedule_id
)

expected_steps_count = len(step_group)

async def _cancel_step(step_name: StepName, step_proxy: StepStoreProxy) -> None:
with log_context( # noqa: SIM117
_logger,
Expand All @@ -212,8 +214,25 @@ async def _cancel_step(step_name: StepName, step_proxy: StepStoreProxy) -> None:
):
with suppress(NoDataFoundError):
deferred_task_uid = await step_proxy.read("deferred_task_uid")
# the deferred task might not be running when this is called
# e.g. cancelling a repeating operation
await DeferredRunner.cancel(deferred_task_uid)
await step_proxy.create_or_update("status", StepStatus.CANCELLED)

await step_proxy.create_or_update("status", StepStatus.CANCELLED)

step_group_name = step_group.get_step_group_name(index=group_index)
group_proxy = StepGroupProxy(
store=self._store,
schedule_id=schedule_id,
operation_name=operation_name,
step_group_name=step_group_name,
is_executing=is_executing,
)
if (
await group_proxy.increment_and_get_done_steps_count()
== expected_steps_count
):
await enqueue_schedule_event(self.app, schedule_id)

await limited_gather(
*(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,11 +222,3 @@ async def on_finished_with_error(
)

await _enqueue_schedule_event_if_group_is_done(context)

@classmethod
async def on_cancelled(cls, context: DeferredContext) -> None:
await get_step_store_proxy(context).create_or_update(
"status", StepStatus.CANCELLED
)

await _enqueue_schedule_event_if_group_is_done(context)
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,43 @@ async def _esnure_steps_have_status(
raise AssertionError(msg) from None


async def _ensure_one_step_in_manual_intervention(
    app: FastAPI,
    schedule_id: ScheduleId,
    operation_name: OperationName,
    *,
    step_group_name: StepGroupName,
    steps: Iterable[type[BaseStep]],
) -> None:
    """Check that at least one step of the group requires manual intervention.

    A ``StepStoreProxy`` (with ``is_executing=True``) is built for every step
    in ``steps``. Inside an ``AsyncRetrying(**_RETRY_PARAMS)`` loop the
    ``requires_manual_intervention`` key of each proxy is read until one of
    them is truthy.

    Args:
        app: application whose state provides the ``Store``.
        schedule_id: schedule the steps belong to.
        operation_name: operation the steps belong to.
        step_group_name: name of the step group containing ``steps``.
        steps: step classes to inspect.

    Raises:
        AssertionError: inside an attempt when no step reports a truthy
            ``requires_manual_intervention``. What finally propagates once the
            attempts are exhausted depends on ``_RETRY_PARAMS``, which is
            defined elsewhere in the file.
    """
    store_proxies = [
        StepStoreProxy(
            store=Store.get_from_app_state(app),
            schedule_id=schedule_id,
            operation_name=operation_name,
            step_group_name=step_group_name,
            step_name=step.get_step_name(),
            is_executing=True,
        )
        for step in steps
    ]

    async for attempt in AsyncRetrying(**_RETRY_PARAMS):
        with attempt:
            requires_intervention = False
            for proxy in store_proxies:
                # keep the ``try`` limited to the only call expected to raise
                try:
                    requires_manual_intervention = await proxy.read(
                        "requires_manual_intervention"
                    )
                except NoDataFoundError:
                    # key not written (yet) for this step: look at the next one
                    continue

                if requires_manual_intervention:
                    requires_intervention = True
                    # leaves only the inner ``for``; the attempt then succeeds
                    break

            assert requires_intervention, (  # noqa: S101
                f"no step in group '{step_group_name}' of operation "
                f"'{operation_name}' requires manual intervention"
            )


############## TESTS ##############


Expand Down Expand Up @@ -853,7 +890,7 @@ async def test_fails_during_revert_is_in_error_state(
RevertRandom(_S2, _S3, _S4),
RevertSequence(_S1),
],
id="s1p3s1(1s)",
id="s1p3s1(1sf)",
),
pytest.param(
Operation(
Expand All @@ -870,7 +907,7 @@ async def test_fails_during_revert_is_in_error_state(
RevertRandom(_S2, _S3, _S4, _SF2, _SF1),
RevertSequence(_S1),
],
id="s1p4(1s)",
id="s1p5(2sf)",
),
],
)
Expand Down Expand Up @@ -1105,21 +1142,28 @@ async def test_wait_for_manual_intervention(

await ensure_keys_in_store(selected_app, expected_keys=formatted_expected_keys)

group_index = len(expected_order) - 1
step_group_name = operation.step_groups[group_index].get_step_group_name(
index=group_index
)
await _esnure_steps_have_status(
selected_app,
schedule_id,
operation_name,
step_group_name=operation.step_groups[
len(expected_order) - 1
].get_step_group_name(index=len(expected_order) - 1),
step_group_name=step_group_name,
steps=expected_order[-1].steps,
)

# even if cancelled, state of waiting for manual intervention remains the same
async for attempt in AsyncRetrying(**_RETRY_PARAMS):
with attempt: # noqa: SIM117
with pytest.raises(CannotCancelWhileWaitingForManualInterventionError):
await cancel_operation(selected_app, schedule_id)
await _ensure_one_step_in_manual_intervention(
selected_app,
schedule_id,
operation_name,
step_group_name=step_group_name,
steps=expected_order[-1].steps,
)
with pytest.raises(CannotCancelWhileWaitingForManualInterventionError):
await cancel_operation(selected_app, schedule_id)

await ensure_keys_in_store(selected_app, expected_keys=formatted_expected_keys)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ async def execute(
_ = app
_ = required_context
_StepResultStore.set_result(cls.__name__, "executed")
await asyncio.sleep(10000)
await asyncio.sleep(1e6)
return {}

@classmethod
Expand All @@ -194,7 +194,7 @@ async def revert(
_ = app
_ = required_context
_StepResultStore.set_result(cls.__name__, "destroyed")
await asyncio.sleep(10000)
await asyncio.sleep(1e6)
return {}


Expand Down Expand Up @@ -240,14 +240,14 @@ def _get_step_group(
Operation(
SingleStepGroup(_StepLongRunningToCancel),
),
StepStatus.CANCELLED,
StepStatus.RUNNING,
_Action.CANCEL,
1,
),
],
)
@pytest.mark.parametrize("is_executing", [True, False])
async def test_something(
async def test_workflow(
mock_enqueue_event: AsyncMock,
registed_operation: None,
app: FastAPI,
Expand Down Expand Up @@ -304,7 +304,7 @@ async def test_something(
await asyncio.sleep(0.2) # give it some time to start

task_uid = await step_proxy.read("deferred_task_uid")
await DeferredRunner.cancel(task_uid)
await asyncio.create_task(DeferredRunner.cancel(task_uid))

await _assert_finshed_with_status(step_proxy, expected_step_status)

Expand All @@ -317,4 +317,9 @@ async def test_something(
assert "I failed" in error_traceback

# ensure called once with arguments
assert mock_enqueue_event.call_args_list == [((app, schedule_id),)]

assert (
mock_enqueue_event.call_args_list == []
if action == _Action.CANCEL
else [((app, schedule_id),)]
)
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ async def execute(
{"key": "value", "dict": {"some": "thing"}, "list": [1, 2, 3]},
],
)
async def test_something(
async def test_workflow(
after_event_manager: AfterEventManager,
store: Store,
schedule_id: ScheduleId,
Expand Down
Loading