
Commit aaec145

bug fix in email search
1 parent 42d10b4 commit aaec145

2 files changed: +6 -6 lines changed

trinity/common/workflows/envs/email_searcher/utils.py

Lines changed: 2 additions & 2 deletions

@@ -288,7 +288,7 @@ def read_email_tool(message_id: str) -> Optional[Email]:
 ############ LLM-as-a-judge ############
 
 
-def judge_correctness(
+async def judge_correctness(
     answer: str,
     query: QueryModel,
     judger: Any,
@@ -318,7 +318,7 @@ def judge_correctness(
         {"role": "system", "content": system_prompt},
         {"role": "user", "content": prompt},
     ]
-    completion = judger.chat.completions.create(
+    completion = await judger.chat.completions.create(
         model=judger.model_path, messages=messages, stream=False
     )
     result = completion.choices[0].message.content
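In effect, judge_correctness becomes a coroutine: the judger argument is expected to expose an async chat.completions.create, and every caller now has to await it. Below is a minimal sketch of the resulting call pattern, assuming an OpenAI-style AsyncOpenAI client and treating the model_path attribute and the judge prompts as illustrative stand-ins, not the repository's actual code:

```python
# A minimal sketch of the async call pattern; the AsyncOpenAI client, the
# model_path attribute, and the judge prompts are assumptions for illustration.
import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    judger = AsyncOpenAI()        # hypothetical async judge client
    judger.model_path = "gpt-4o"  # assumed attribute read inside judge_correctness

    completion = await judger.chat.completions.create(
        model=judger.model_path,
        messages=[
            {"role": "system", "content": "You are a strict correctness judge."},
            {"role": "user", "content": "Query: ...\nAnswer: ...\nIs the answer correct?"},
        ],
        stream=False,
    )
    print(completion.choices[0].message.content)


if __name__ == "__main__":
    asyncio.run(main())
```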

trinity/common/workflows/envs/email_searcher/workflow.py

Lines changed: 4 additions & 4 deletions

@@ -92,7 +92,7 @@ async def run_async(self):
             experiences
         )  # NOTE: this metrics works only if the agent calls model once in each turn
 
-        reward_dict = self.calculate_reward(answer_and_sources)
+        reward_dict = await self.calculate_reward(answer_and_sources)
         reward = sum(reward_dict.values())
 
         for i, experience in enumerate(experiences):
@@ -107,7 +107,7 @@ async def run_async(self):
         )
         return experiences
 
-    def calculate_reward(self, answer_and_sources: Dict) -> Dict[str, float]:
+    async def calculate_reward(self, answer_and_sources: Dict) -> Dict[str, float]:
         """Ref: calculate_reward in https://github.com/OpenPipe/ART/blob/main/dev/art-e/art_e/rollout.py#L64"""
         try:
             answer = answer_and_sources.get("answer", None)
@@ -140,7 +140,7 @@ def calculate_reward(self, answer_and_sources: Dict) -> Dict[str, float]:
 
         try:
             judge_model = self.auxiliary_models[0] if self.auxiliary_models else None
-            judge_response = judge_correctness(answer, self.query, judge_model)
+            judge_response = await judge_correctness(answer, self.query, judge_model)
             rubric.answer_correct = judge_response
 
         except Exception as e:
@@ -179,4 +179,4 @@ def calculate_reward(self, answer_and_sources: Dict) -> Dict[str, float]:
             return result
 
         self.logger.error(f"Rubric {rubric} not handled properly")
-        raise ValueError("Rubric is not handled properly")
+        return {"accuracy": 0.0, "format": 0.0}
