@@ -195,6 +195,7 @@ def act(self, goal: Goal, toolbox: Toolbox) -> Action:
             {"role": "user", "content": goal.prompt},
         ]
 
+        request_count = 0
         while True:
             response = self._client.chat.completions.create(
                 model=self._model,
@@ -203,6 +204,7 @@ def act(self, goal: Goal, toolbox: Toolbox) -> Action:
                 tool_choice="required",
             )
             assert len(response.choices) == 1
+            request_count += 1
 
             done = True
             calls = response.choices[0].message.tool_calls
@@ -214,7 +216,7 @@ def act(self, goal: Goal, toolbox: Toolbox) -> Action:
             if done:
                 break
 
-        return Action()
+        return Action(request_count=request_count)
 
 
 class _CompletionsToolHandler(_ToolHandler[str | None]):
@@ -262,41 +264,58 @@ def create(cls, client: openai.OpenAI, model: str) -> Self:
         return cls(client, assistant_id)
 
     def act(self, goal: Goal, toolbox: Toolbox) -> Action:
-        # TODO: Use timeout.
         thread = self._client.beta.threads.create()
-
         self._client.beta.threads.messages.create(
             thread_id=thread.id,
             role="user",
             content=goal.prompt,
         )
 
+        # We intentionally do not count the two requests above, to focus on
+        # "data requests" only.
+        action = Action(request_count=0, token_count=0)
         with self._client.beta.threads.runs.stream(
             thread_id=thread.id,
             assistant_id=self._assistant_id,
-            event_handler=_EventHandler(self._client, toolbox),
+            event_handler=_EventHandler(self._client, toolbox, action),
         ) as stream:
             stream.until_done()
-
-        return Action()
+        return action
 
 
 class _EventHandler(openai.AssistantEventHandler):
-    def __init__(self, client: openai.Client, toolbox: Toolbox) -> None:
+    def __init__(
+        self, client: openai.Client, toolbox: Toolbox, action: Action
+    ) -> None:
         super().__init__()
         self._client = client
         self._toolbox = toolbox
+        self._action = action
+        self._action.increment_request_count()
 
-    def clone(self) -> Self:
-        return self.__class__(self._client, self._toolbox)
+    def _clone(self) -> Self:
+        return self.__class__(self._client, self._toolbox, self._action)
 
     @override
-    def on_event(self, event: Any) -> None:
-        _logger.debug("Event: %s", event)
+    def on_event(self, event: openai.types.beta.AssistantStreamEvent) -> None:
         if event.event == "thread.run.requires_action":
             run_id = event.data.id  # Retrieve the run ID from the event data
             self._handle_action(run_id, event.data)
-        # TODO: Handle (log?) other events.
+        elif event.event == "thread.run.completed":
+            _logger.info("Threads run completed. [usage=%s]", event.data.usage)
+        else:
+            _logger.debug("Threads event: %s", event)
+
+    @override
+    def on_run_step_done(
+        self, run_step: openai.types.beta.threads.runs.RunStep
+    ) -> None:
+        usage = run_step.usage
+        if usage:
+            _logger.debug("Threads run step usage: %s", usage)
+            self._action.increment_token_count(usage.total_tokens)
+        else:
+            _logger.warning("Missing usage in threads run step")
 
     def _handle_action(self, run_id: str, data: Any) -> None:
         tool_outputs = list[Any]()
@@ -310,7 +329,7 @@ def _handle_action(self, run_id: str, data: Any) -> None:
             thread_id=run.thread_id,
             run_id=run.id,
             tool_outputs=tool_outputs,
-            event_handler=self.clone(),
+            event_handler=self._clone(),
         ) as stream:
             stream.until_done()
 
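
Note: the diff calls an Action value with increment_request_count() and increment_token_count() methods that is not shown in these hunks. A minimal sketch of what such a counter could look like, assuming plain mutable integer fields and the call sites visible above (Action(), Action(request_count=...), Action(request_count=0, token_count=0)); the real definition may differ:

from dataclasses import dataclass


@dataclass
class Action:
    """Hypothetical sketch: accumulates usage stats for one agent action."""

    request_count: int = 0
    token_count: int = 0

    def increment_request_count(self, count: int = 1) -> None:
        # Called once per API request issued while handling the action.
        self.request_count += count

    def increment_token_count(self, count: int) -> None:
        # Called with usage.total_tokens from each completed run step.
        self.token_count += count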