Skip to content

Commit e95c923

Browse files
authored
Fix openai api history (#428)
1 parent 80c3b1f commit e95c923

File tree

2 files changed

+64
-0
lines changed

2 files changed

+64
-0
lines changed

tests/explorer/workflow_test.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,26 @@ async def model_version_async(self):
610610
return 0
611611

612612

613+
class APIWorkflow(Workflow):
614+
is_async: bool = True
615+
616+
def __init__(self, model: ModelWrapper, task: Task, auxiliary_models=None):
617+
super().__init__(task=task, model=model, auxiliary_models=auxiliary_models)
618+
self.client = model.get_openai_async_client()
619+
self.raise_except = task.raw_task.get("raise_except", False)
620+
621+
async def run_async(self):
622+
_ = await self.client.chat.completions.create(
623+
model=self.client.model_path,
624+
messages=[{"role": "user", "content": "Hello!"}],
625+
)
626+
if self.raise_except:
627+
raise RuntimeError("Intentional Exception for testing.")
628+
exps = self.model.extract_experience_from_history()
629+
exps[0].reward = 0.5
630+
return exps
631+
632+
613633
class TestWorkflowRunner(unittest.IsolatedAsyncioTestCase):
614634
async def test_workflow_runner(self):
615635
config = get_template_config()
@@ -697,3 +717,46 @@ async def monitor_routine():
697717
await asyncio.gather(
698718
*[monitor_routine(), runner.run_task(task, repeat_times=3, run_id_base=0)]
699719
)
720+
721+
async def test_workflow_with_openai(self):
722+
config = get_template_config()
723+
config.mode = "explore"
724+
config.model.model_path = get_model_path()
725+
config.explorer.rollout_model.engine_num = 1
726+
config.explorer.rollout_model.enable_openai_api = True
727+
config.explorer.rollout_model.enable_history = True
728+
config.check_and_update()
729+
engines, auxiliary_engines = create_inference_models(config)
730+
731+
runner = WorkflowRunner(
732+
config,
733+
model=engines[0],
734+
auxiliary_models=[],
735+
runner_id=0,
736+
)
737+
await runner.prepare()
738+
tasks = [
739+
Task(
740+
workflow=APIWorkflow,
741+
raw_task={"raise_except": True},
742+
repeat_times=2,
743+
),
744+
Task(
745+
workflow=APIWorkflow,
746+
raw_task={},
747+
repeat_times=2,
748+
),
749+
]
750+
751+
status, exps = await runner.run_task(
752+
tasks[0], repeat_times=2, run_id_base=0
753+
) # test exception handling
754+
self.assertEqual(status.ok, False)
755+
self.assertEqual(len(exps), 0)
756+
exps = runner.model_wrapper.extract_experience_from_history(clear_history=False)
757+
self.assertEqual(len(exps), 1)
758+
status, exps = await runner.run_task(tasks[1], repeat_times=2, run_id_base=0) # normal run
759+
self.assertEqual(status.ok, True)
760+
self.assertEqual(len(exps), 2)
761+
exps = runner.model_wrapper.extract_experience_from_history(clear_history=False)
762+
self.assertEqual(len(exps), 0)

trinity/common/models/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,7 @@ async def clean_workflow_state(self) -> None:
376376
"""Clean the state of workflow using the model."""
377377
async with self.state_lock:
378378
self.workflow_state = {}
379+
self.history.clear()
379380

380381
async def get_workflow_state(self) -> Dict:
381382
"""Get the state of workflow using the model."""

0 commit comments

Comments
 (0)