@@ -691,6 +691,190 @@ async def parent_workflow() -> tuple[str, str, str]:
691691 assert async_child_status .parent_workflow_id == parent_id
692692
693693
694+ @pytest .mark .asyncio
695+ async def test_asyncio_wait (dbos : DBOS ) -> None :
696+ step_counter : int = 0
697+ gate = asyncio .Event ()
698+
699+ @DBOS .step ()
700+ async def fast_step (val : str ) -> str :
701+ nonlocal step_counter
702+ step_counter += 1
703+ return val + "_done"
704+
705+ @DBOS .step ()
706+ async def slow_step (val : str ) -> str :
707+ nonlocal step_counter
708+ step_counter += 1
709+ await gate .wait ()
710+ return val + "_done"
711+
712+ @DBOS .workflow ()
713+ async def wait_workflow () -> None :
714+ done , pending = await DBOS .asyncio_wait (
715+ [fast_step ("fast" ), slow_step ("slow" )],
716+ return_when = asyncio .FIRST_COMPLETED ,
717+ )
718+
719+ assert len (done ) == 1
720+ assert len (pending ) == 1
721+ assert [t .result () for t in done ] == ["fast_done" ]
722+
723+ # Let the slow step finish and wait for it
724+ gate .set ()
725+ done2 , pending2 = await DBOS .asyncio_wait (list (pending ))
726+ assert len (done2 ) == 1
727+ assert len (pending2 ) == 0
728+ assert [t .result () for t in done2 ] == ["slow_done" ]
729+
730+ handle = await DBOS .start_workflow_async (wait_workflow )
731+ await handle .get_result ()
732+ assert step_counter == 2
733+
734+ # Verify recorded steps
735+ steps = await DBOS .list_workflow_steps_async (handle .workflow_id )
736+ assert len (steps ) == 4
737+ # Step 1: first asyncio_wait snapshots its context before the tasks run.
738+ # Recorded done indices [0] means the first future (fast_step) completed.
739+ assert steps [0 ]["function_id" ] == 1
740+ assert steps [0 ]["function_name" ] == "DBOS.asyncio_wait"
741+ assert steps [0 ]["output" ] == [0 ]
742+ # Steps 2 & 3: the step coroutines execute inside the asyncio tasks
743+ assert steps [1 ]["function_id" ] == 2
744+ assert steps [1 ]["function_name" ] == fast_step .__qualname__
745+ assert steps [1 ]["output" ] == "fast_done"
746+ assert steps [2 ]["function_id" ] == 3
747+ assert steps [2 ]["function_name" ] == slow_step .__qualname__
748+ assert steps [2 ]["output" ] == "slow_done"
749+ # Step 4: second asyncio_wait on the pending set
750+ assert steps [3 ]["function_id" ] == 4
751+ assert steps [3 ]["function_name" ] == "DBOS.asyncio_wait"
752+ assert steps [3 ]["output" ] == [0 ]
753+
754+ # Fork from a high step to replay everything from DB (OAOO)
755+ forked = await DBOS .fork_workflow_async (handle .workflow_id , 100 )
756+ await forked .get_result ()
757+ assert step_counter == 2
758+
759+
760+ @pytest .mark .asyncio
761+ async def test_asyncio_wait_all_completed (dbos : DBOS ) -> None :
762+ step_counter : int = 0
763+
764+ @DBOS .step ()
765+ async def my_step (val : str ) -> str :
766+ nonlocal step_counter
767+ step_counter += 1
768+ return val
769+
770+ @DBOS .workflow ()
771+ async def wait_all_workflow () -> None :
772+ done , pending = await DBOS .asyncio_wait (
773+ [my_step ("a" ), my_step ("b" )], return_when = asyncio .ALL_COMPLETED
774+ )
775+
776+ assert len (done ) == 2
777+ assert len (pending ) == 0
778+ assert sorted ([t .result () for t in done ]) == ["a" , "b" ]
779+
780+ handle = await DBOS .start_workflow_async (wait_all_workflow )
781+ await handle .get_result ()
782+ assert step_counter == 2
783+
784+ # Fork from a high step to replay everything from DB (OAOO)
785+ forked = await DBOS .fork_workflow_async (handle .workflow_id , 100 )
786+ await forked .get_result ()
787+ assert step_counter == 2
788+
789+
790+ @pytest .mark .asyncio
791+ async def test_asyncio_wait_first_exception (dbos : DBOS ) -> None :
792+ step_counter : int = 0
793+ gate = asyncio .Event ()
794+
795+ @DBOS .step ()
796+ async def error_step () -> str :
797+ nonlocal step_counter
798+ step_counter += 1
799+ raise ValueError ("boom" )
800+
801+ @DBOS .step ()
802+ async def slow_step (val : str ) -> str :
803+ nonlocal step_counter
804+ step_counter += 1
805+ await gate .wait ()
806+ return val
807+
808+ @DBOS .workflow ()
809+ async def wait_exception_workflow () -> None :
810+ done , pending = await DBOS .asyncio_wait (
811+ [error_step (), slow_step ("ok" )],
812+ return_when = asyncio .FIRST_EXCEPTION ,
813+ )
814+
815+ assert len (done ) == 1
816+ assert len (pending ) == 1
817+ task = next (iter (done ))
818+ assert isinstance (task .exception (), ValueError )
819+ assert "boom" in str (task .exception ())
820+
821+ # Let the slow step finish and wait for it
822+ gate .set ()
823+ done2 , pending2 = await DBOS .asyncio_wait (list (pending ))
824+ assert len (done2 ) == 1
825+ assert len (pending2 ) == 0
826+ assert next (iter (done2 )).result () == "ok"
827+
828+ handle = await DBOS .start_workflow_async (wait_exception_workflow )
829+ await handle .get_result ()
830+ assert step_counter == 2
831+
832+ # Fork from a high step to replay everything from DB (OAOO)
833+ forked = await DBOS .fork_workflow_async (handle .workflow_id , 100 )
834+ await forked .get_result ()
835+ assert step_counter == 2
836+
837+
838+ @pytest .mark .asyncio
839+ async def test_asyncio_wait_timeout (dbos : DBOS ) -> None :
840+ step_counter : int = 0
841+ gate = asyncio .Event ()
842+
843+ @DBOS .step ()
844+ async def blocked_step (val : str ) -> str :
845+ nonlocal step_counter
846+ step_counter += 1
847+ await gate .wait ()
848+ return val
849+
850+ @DBOS .workflow ()
851+ async def wait_timeout_workflow () -> None :
852+ done , pending = await DBOS .asyncio_wait (
853+ [blocked_step ("a" ), blocked_step ("b" )],
854+ timeout = 0.1 ,
855+ )
856+
857+ # Both should still be pending after timeout
858+ assert len (done ) == 0
859+ assert len (pending ) == 2
860+
861+ # Unblock and wait for all
862+ gate .set ()
863+ done2 , pending2 = await DBOS .asyncio_wait (list (pending ))
864+ assert len (done2 ) == 2
865+ assert len (pending2 ) == 0
866+ assert sorted ([t .result () for t in done2 ]) == ["a" , "b" ]
867+
868+ handle = await DBOS .start_workflow_async (wait_timeout_workflow )
869+ await handle .get_result ()
870+ assert step_counter == 2
871+
872+ # Fork from a high step to replay everything from DB (OAOO)
873+ forked = await DBOS .fork_workflow_async (handle .workflow_id , 100 )
874+ await forked .get_result ()
875+ assert step_counter == 2
876+
877+
694878@pytest .mark .asyncio
695879async def test_workflow_recovery_async (dbos : DBOS , config : DBOSConfig ) -> None :
696880 DBOS .destroy (destroy_registry = True )
0 commit comments