
Commit 731786d

Commit message: looking good
1 parent 7e49aa7 commit 731786d

2 files changed: +60 -2 lines changed


src/agentlab/llm/chat_api.py

Lines changed: 6 additions & 2 deletions

@@ -261,7 +261,7 @@ def __init__(
             **client_args,
         )
 
-    def __call__(self, messages: list[dict]) -> dict:
+    def __call__(self, messages: list[dict], n_samples: int = 1) -> dict:
         # Initialize retry tracking attributes
         self.retries = 0
         self.success = False
@@ -275,6 +275,7 @@ def __call__(self, messages: list[dict]) -> dict:
         completion = self.client.chat.completions.create(
             model=self.model_name,
             messages=messages,
+            n=n_samples,
             temperature=self.temperature,
             max_tokens=self.max_tokens,
         )
@@ -305,7 +306,10 @@ def __call__(self, messages: list[dict]) -> dict:
         ):
             tracking.TRACKER.instance(input_tokens, output_tokens, cost)
 
-        return AIMessage(completion.choices[0].message.content)
+        if n_samples == 1:
+            return AIMessage(completion.choices[0].message.content)
+        else:
+            return [AIMessage(c.message.content) for c in completion.choices]
 
     def get_stats(self):
         return {
src/agentlab/llm/llm_utils.py

Lines changed: 54 additions & 0 deletions

@@ -90,6 +90,54 @@ def retry(
     raise ParseError(f"Could not parse a valid value after {n_retry} retries.")
 
 
+def retry_multiple(
+    chat: "ChatModel",
+    messages: "Discussion",
+    n_retry: int,
+    parser: callable,
+    log: bool = True,
+    num_samples: int = 1,
+):
+    """Retry querying the chat model with the response from the parser until it
+    returns a valid value.
+
+    If none of the sampled answers is valid, it will retry and append the parsing
+    errors to the chat. It will stop after `n_retry`.
+
+    Note, each retry has to resend the whole prompt to the API. This can be slow
+    and expensive.
+
+    Args:
+        chat (ChatModel): a ChatModel object taking a list of messages and
+            returning a list of answers, all in OpenAI format.
+        messages (list): the list of messages so far. This list will be modified with
+            the new messages and the retry messages.
+        n_retry (int): the maximum number of sequential retries.
+        parser (callable): a function taking a message and returning a parsed value,
+            or raising a ParseError.
+        log (bool): whether to log the retry messages.
+        num_samples (int): the number of answers to sample from the chat model at each attempt.
+
+    Returns:
+        list: the parsed values, one per successfully parsed sample, each with a string at key "action".
+
+    Raises:
+        ParseError: if no response could be parsed after n_retry retries.
+    """
+    tries = 0
+    while tries < n_retry:
+        answer_list = chat(messages, n_samples=num_samples)
+        # TODO: could we change this to not use inplace modifications ?
+        messages.append(answer_list[0])
+        parsed_answers = []
+        errors = []
+        for answer in answer_list:
+            try:
+                parsed_answers.append(parser(answer["content"]))
+            except ParseError as parsing_error:
+                errors.append(str(parsing_error))
+        if parsed_answers:
+            return parsed_answers
+        tries += 1
+        if log:
+            msg = f"Query failed. Retrying {tries}/{n_retry}.\n[LLM]:\n{answer['content']}\n[User]:\n{str(errors)}"
+            logging.info(msg)
+        messages.append(dict(role="user", content=str(errors)))
+
+    raise ParseError(f"Could not parse a valid value after {n_retry} retries.")
+
+
 def truncate_tokens(text, max_tokens=8000, start=0, model_name="gpt-4"):
     """Use tiktoken to truncate a text to a maximum number of tokens."""
     enc = tiktoken.encoding_for_model(model_name)