@@ -610,6 +610,26 @@ async def model_version_async(self):
610610 return 0
611611
612612
613+ class APIWorkflow (Workflow ):
614+ is_async : bool = True
615+
616+ def __init__ (self , model : ModelWrapper , task : Task , auxiliary_models = None ):
617+ super ().__init__ (task = task , model = model , auxiliary_models = auxiliary_models )
618+ self .client = model .get_openai_async_client ()
619+ self .raise_except = task .raw_task .get ("raise_except" , False )
620+
621+ async def run_async (self ):
622+ _ = await self .client .chat .completions .create (
623+ model = self .client .model_path ,
624+ messages = [{"role" : "user" , "content" : "Hello!" }],
625+ )
626+ if self .raise_except :
627+ raise RuntimeError ("Intentional Exception for testing." )
628+ exps = self .model .extract_experience_from_history ()
629+ exps [0 ].reward = 0.5
630+ return exps
631+
632+
613633class TestWorkflowRunner (unittest .IsolatedAsyncioTestCase ):
614634 async def test_workflow_runner (self ):
615635 config = get_template_config ()
@@ -697,3 +717,46 @@ async def monitor_routine():
697717 await asyncio .gather (
698718 * [monitor_routine (), runner .run_task (task , repeat_times = 3 , run_id_base = 0 )]
699719 )
720+
721+ async def test_workflow_with_openai (self ):
722+ config = get_template_config ()
723+ config .mode = "explore"
724+ config .model .model_path = get_model_path ()
725+ config .explorer .rollout_model .engine_num = 1
726+ config .explorer .rollout_model .enable_openai_api = True
727+ config .explorer .rollout_model .enable_history = True
728+ config .check_and_update ()
729+ engines , auxiliary_engines = create_inference_models (config )
730+
731+ runner = WorkflowRunner (
732+ config ,
733+ model = engines [0 ],
734+ auxiliary_models = [],
735+ runner_id = 0 ,
736+ )
737+ await runner .prepare ()
738+ tasks = [
739+ Task (
740+ workflow = APIWorkflow ,
741+ raw_task = {"raise_except" : True },
742+ repeat_times = 2 ,
743+ ),
744+ Task (
745+ workflow = APIWorkflow ,
746+ raw_task = {},
747+ repeat_times = 2 ,
748+ ),
749+ ]
750+
751+ status , exps = await runner .run_task (
752+ tasks [0 ], repeat_times = 2 , run_id_base = 0
753+ ) # test exception handling
754+ self .assertEqual (status .ok , False )
755+ self .assertEqual (len (exps ), 0 )
756+ exps = runner .model_wrapper .extract_experience_from_history (clear_history = False )
757+ self .assertEqual (len (exps ), 1 )
758+ status , exps = await runner .run_task (tasks [1 ], repeat_times = 2 , run_id_base = 0 ) # normal run
759+ self .assertEqual (status .ok , True )
760+ self .assertEqual (len (exps ), 2 )
761+ exps = runner .model_wrapper .extract_experience_from_history (clear_history = False )
762+ self .assertEqual (len (exps ), 0 )
0 commit comments