@@ -90,6 +90,60 @@ def retry(
9090 raise ParseError (f"Could not parse a valid value after { n_retry } retries." )
9191
9292
93+ def retry_multiple (
94+ chat : "ChatModel" ,
95+ messages : "Discussion" ,
96+ n_retry : int ,
97+ parser : callable ,
98+ log : bool = True ,
99+ num_samples : int = 1 ,
100+ ):
101+ """Retry querying the chat models with the response from the parser until it
102+ returns a valid value.
103+
104+ If the answer is not valid, it will retry and append to the chat the retry
105+ message. It will stop after `n_retry`.
106+
107+ Note, each retry has to resend the whole prompt to the API. This can be slow
108+ and expensive.
109+
110+ Args:
111+ chat (ChatModel): a ChatModel object taking a list of messages and
112+ returning a list of answers, all in OpenAI format.
113+ messages (list): the list of messages so far. This list will be modified with
114+ the new messages and the retry messages.
115+ n_retry (int): the maximum number of sequential retries.
116+ parser (callable): a function taking a message and retruning a parsed value,
117+ or raising a ParseError
118+ log (bool): whether to log the retry messages.
119+
120+ Returns:
121+ dict: the parsed value, with a string at key "action".
122+
123+ Raises:
124+ ParseError: if the parser could not parse the response after n_retry retries.
125+ """
126+ tries = 0
127+ while tries < n_retry :
128+ answer_list = chat (messages , num_samples = num_samples )
129+ # TODO: could we change this to not use inplace modifications ?
130+ messages .append (answer )
131+ parsed_answers = []
132+ errors = []
133+ for answer in answer_list :
134+ try :
135+ parsed_answers .append (parser (answer ["content" ]))
136+ except ParseError as parsing_error :
137+ errors .append (str (parsing_error ))
138+ tries += 1
139+ if log :
140+ msg = f"Query failed. Retrying { tries } /{ n_retry } .\n [LLM]:\n { answer ['content' ]} \n [User]:\n { str (errors )} "
141+ logging .info (msg )
142+ messages .append (dict (role = "user" , content = str (errors )))
143+
144+ raise ParseError (f"Could not parse a valid value after { n_retry } retries." )
145+
146+
93147def truncate_tokens (text , max_tokens = 8000 , start = 0 , model_name = "gpt-4" ):
94148 """Use tiktoken to truncate a text to a maximum number of tokens."""
95149 enc = tiktoken .encoding_for_model (model_name )
0 commit comments