diff --git a/tests/explorer/workflow_test.py b/tests/explorer/workflow_test.py index e0a86da6f1..31b232e148 100644 --- a/tests/explorer/workflow_test.py +++ b/tests/explorer/workflow_test.py @@ -610,6 +610,26 @@ async def model_version_async(self): return 0 +class APIWorkflow(Workflow): + is_async: bool = True + + def __init__(self, model: ModelWrapper, task: Task, auxiliary_models=None): + super().__init__(task=task, model=model, auxiliary_models=auxiliary_models) + self.client = model.get_openai_async_client() + self.raise_except = task.raw_task.get("raise_except", False) + + async def run_async(self): + _ = await self.client.chat.completions.create( + model=self.client.model_path, + messages=[{"role": "user", "content": "Hello!"}], + ) + if self.raise_except: + raise RuntimeError("Intentional Exception for testing.") + exps = self.model.extract_experience_from_history() + exps[0].reward = 0.5 + return exps + + class TestWorkflowRunner(unittest.IsolatedAsyncioTestCase): async def test_workflow_runner(self): config = get_template_config() @@ -697,3 +717,46 @@ async def monitor_routine(): await asyncio.gather( *[monitor_routine(), runner.run_task(task, repeat_times=3, run_id_base=0)] ) + + async def test_workflow_with_openai(self): + config = get_template_config() + config.mode = "explore" + config.model.model_path = get_model_path() + config.explorer.rollout_model.engine_num = 1 + config.explorer.rollout_model.enable_openai_api = True + config.explorer.rollout_model.enable_history = True + config.check_and_update() + engines, auxiliary_engines = create_inference_models(config) + + runner = WorkflowRunner( + config, + model=engines[0], + auxiliary_models=[], + runner_id=0, + ) + await runner.prepare() + tasks = [ + Task( + workflow=APIWorkflow, + raw_task={"raise_except": True}, + repeat_times=2, + ), + Task( + workflow=APIWorkflow, + raw_task={}, + repeat_times=2, + ), + ] + + status, exps = await runner.run_task( + tasks[0], repeat_times=2, run_id_base=0 + ) # test exception handling + self.assertEqual(status.ok, False) + self.assertEqual(len(exps), 0) + exps = runner.model_wrapper.extract_experience_from_history(clear_history=False) + self.assertEqual(len(exps), 1) + status, exps = await runner.run_task(tasks[1], repeat_times=2, run_id_base=0) # normal run + self.assertEqual(status.ok, True) + self.assertEqual(len(exps), 2) + exps = runner.model_wrapper.extract_experience_from_history(clear_history=False) + self.assertEqual(len(exps), 0) diff --git a/trinity/common/models/model.py b/trinity/common/models/model.py index b4f9fd207c..483915b893 100644 --- a/trinity/common/models/model.py +++ b/trinity/common/models/model.py @@ -376,6 +376,7 @@ async def clean_workflow_state(self) -> None: """Clean the state of workflow using the model.""" async with self.state_lock: self.workflow_state = {} + self.history.clear() async def get_workflow_state(self) -> Dict: """Get the state of workflow using the model."""