@@ -615,11 +615,17 @@ def test_step(i: int) -> int:
615615 original_handle = DBOS .start_workflow (test_workflow )
616616 for e in step_events :
617617 e .wait ()
618+ e .clear ()
619+
618620 assert step_counter == 5
619621
620622 # Recover the workflow, then resume it.
621623 recovery_handles = DBOS .recover_pending_workflows ()
624+ # Wait until the 2nd invocation of each workflow is dequeued and executed
625+ for e in step_events :
626+ e .wait ()
622627 event .set ()
628+
623629 # There should be one handle for the workflow and another for each queued step.
624630 assert len (recovery_handles ) == queued_steps + 1
625631 # Verify that both the recovered and original workflows complete correctly.
@@ -639,6 +645,84 @@ def test_step(i: int) -> int:
639645 assert queue_entries_are_cleaned_up (dbos )
640646
641647
def test_queue_concurrency_under_recovery(dbos: DBOS) -> None:
    """Verify queue concurrency limits are enforced across workflow recovery.

    Two blocking workflows saturate a queue with concurrency=2 while a third
    (noop) stays ENQUEUED. Recovering a foreign executor's workflow re-enqueues
    only that workflow; recovering "local" re-enqueues the two blocked
    workflows, which keep the third head-of-line blocked until they finish.
    """
    event = threading.Event()
    wf_events = [threading.Event() for _ in range(2)]
    counter = 0

    @DBOS.workflow()
    def blocked_workflow(i: int) -> None:
        nonlocal counter
        counter += 1
        # Signal readiness only AFTER updating the counter so the main
        # thread's `counter` assertions cannot race with this increment.
        wf_events[i].set()
        event.wait()

    @DBOS.workflow()
    def noop() -> None:
        pass

    queue = Queue("test_queue", concurrency=2)
    handle1 = queue.enqueue(blocked_workflow, 0)
    handle2 = queue.enqueue(blocked_workflow, 1)
    handle3 = queue.enqueue(noop)

    # Wait for the first two workflows to be dequeued; clear the events so
    # they can be reused to detect the post-recovery (re-)executions.
    for e in wf_events:
        e.wait()
        e.clear()

    assert counter == 2
    assert handle1.get_status().status == WorkflowStatusString.PENDING.value
    assert handle2.get_status().status == WorkflowStatusString.PENDING.value
    assert handle3.get_status().status == WorkflowStatusString.ENQUEUED.value

    # Manually update the database to pretend the 3rd workflow is PENDING
    # and owned by another executor.
    with dbos._sys_db.engine.begin() as c:
        query = (
            sa.update(SystemSchema.workflow_status)
            .values(status=WorkflowStatusString.PENDING.value, executor_id="other")
            .where(
                SystemSchema.workflow_status.c.workflow_uuid
                == handle3.get_workflow_id()
            )
        )
        c.execute(query)

    # Trigger recovery for the foreign executor. The first two workflows
    # should still be blocked, but the 3rd one should be re-enqueued.
    recovered_other_handles = DBOS.recover_pending_workflows(["other"])
    assert handle1.get_status().status == WorkflowStatusString.PENDING.value
    assert handle2.get_status().status == WorkflowStatusString.PENDING.value
    assert len(recovered_other_handles) == 1
    assert recovered_other_handles[0].get_workflow_id() == handle3.get_workflow_id()
    assert handle3.get_status().status == WorkflowStatusString.ENQUEUED.value

    # Trigger recovery for "local". The first two workflows should be
    # re-enqueued and then dequeued again.
    recovered_local_handles = DBOS.recover_pending_workflows(["local"])
    assert len(recovered_local_handles) == 2
    blocked_ids = {handle1.get_workflow_id(), handle2.get_workflow_id()}
    for h in recovered_local_handles:
        assert h.get_workflow_id() in blocked_ids
    for e in wf_events:
        e.wait()
    assert counter == 4
    assert handle1.get_status().status == WorkflowStatusString.PENDING.value
    assert handle2.get_status().status == WorkflowStatusString.PENDING.value
    # Because tasks are re-enqueued in order, the 3rd task is head-of-line blocked
    assert handle3.get_status().status == WorkflowStatusString.ENQUEUED.value

    # Unblock the first two workflows
    event.set()

    # Verify all workflows complete and queue entries eventually get cleaned up.
    assert handle1.get_result() is None
    assert handle2.get_result() is None
    assert handle3.get_result() is None
    assert handle3.get_status().executor_id == "local"
    assert queue_entries_are_cleaned_up(dbos)
725+
642726def test_cancelling_queued_workflows (dbos : DBOS ) -> None :
643727 start_event = threading .Event ()
644728 blocking_event = threading .Event ()
@@ -746,17 +830,28 @@ def regular_workflow() -> None:
746830
747831 # Attempt to recover the blocked workflow the maximum number of times
748832 for i in range (max_recovery_attempts ):
833+ start_event .clear ()
749834 DBOS .recover_pending_workflows ()
835+ start_event .wait ()
750836 assert recovery_count == i + 2
751837
752- # Verify an additional recovery throws a DLQ error and puts the workflow in the DLQ status.
753- with pytest . raises ( Exception ) as exc_info :
754- DBOS . recover_pending_workflows ()
755- assert exc_info . errisinstance ( DBOSDeadLetterQueueError )
838+ # Verify an additional recovery puts the workflow in the DLQ status.
839+ DBOS . recover_pending_workflows ()
840+ # we can't start_event.wait() here because the workflow will never execute
841+ time . sleep ( 2 )
756842 assert (
757843 blocked_handle .get_status ().status
758844 == WorkflowStatusString .RETRIES_EXCEEDED .value
759845 )
846+ with dbos ._sys_db .engine .begin () as c :
847+ query = sa .select (SystemSchema .workflow_status .c .recovery_attempts ).where (
848+ SystemSchema .workflow_status .c .workflow_uuid
849+ == blocked_handle .get_workflow_id ()
850+ )
851+ result = c .execute (query )
852+ row = result .fetchone ()
853+ assert row is not None
854+ assert row [0 ] == max_recovery_attempts + 2
760855
761856 # Verify the blocked workflow entering the DLQ lets the regular workflow run
762857 assert regular_handle .get_result () == None
@@ -766,6 +861,15 @@ def regular_workflow() -> None:
766861 assert blocked_handle .get_result () == None
767862 dbos ._sys_db .wait_for_buffer_flush ()
768863 assert blocked_handle .get_status ().status == WorkflowStatusString .SUCCESS .value
864+ with dbos ._sys_db .engine .begin () as c :
865+ query = sa .select (SystemSchema .workflow_status .c .recovery_attempts ).where (
866+ SystemSchema .workflow_status .c .workflow_uuid
867+ == blocked_handle .get_workflow_id ()
868+ )
869+ result = c .execute (query )
870+ row = result .fetchone ()
871+ assert row is not None
872+ assert row [0 ] == max_recovery_attempts + 2
769873
770874 # Verify all queue entries eventually get cleaned up.
771875 assert queue_entries_are_cleaned_up (dbos )
0 commit comments