@@ -33,14 +33,21 @@ def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
3333 if messages [0 ]["role" ] == "system" :
3434 kwargs ["messages" ] = messages [1 :]
3535 kwargs ["system" ] = messages [0 ]["content" ] # set system prompt here
36+ if self .config .reasoning :
37+ kwargs ["thinking" ] = {"type" : "enabled" , "budget_tokens" : self .config .reasoning_tokens }
3638 return kwargs
3739
3840 def _update_costs (self , usage : Usage , model : str = None , local_calc_usage : bool = True ):
3941 usage = {"prompt_tokens" : usage .input_tokens , "completion_tokens" : usage .output_tokens }
4042 super ()._update_costs (usage , model )
4143
4244 def get_choice_text (self , resp : Message ) -> str :
43- return resp .content [0 ].text
45+ if len (resp .content ) > 0 :
46+ self .reasoning_content = resp .content [0 ].thinking
47+ text = resp .content [1 ].text
48+ else :
49+ text = resp .content [0 ].text
50+ return text
4451
4552 async def _achat_completion (self , messages : list [dict ], timeout : int = USE_CONFIG_TIMEOUT ) -> Message :
4653 resp : Message = await self .aclient .messages .create (** self ._const_kwargs (messages ))
@@ -53,20 +60,27 @@ async def acompletion(self, messages: list[dict], timeout: int = USE_CONFIG_TIME
5360 async def _achat_completion_stream (self , messages : list [dict ], timeout : int = USE_CONFIG_TIMEOUT ) -> str :
5461 stream = await self .aclient .messages .create (** self ._const_kwargs (messages , stream = True ))
5562 collected_content = []
63+ collected_reasoning_content = []
5664 usage = Usage (input_tokens = 0 , output_tokens = 0 )
5765 async for event in stream :
5866 event_type = event .type
5967 if event_type == "message_start" :
6068 usage .input_tokens = event .message .usage .input_tokens
6169 usage .output_tokens = event .message .usage .output_tokens
6270 elif event_type == "content_block_delta" :
63- content = event .delta .text
64- log_llm_stream (content )
65- collected_content .append (content )
71+ delta_type = event .delta .type
72+ if delta_type == "thinking_delta" :
73+ collected_reasoning_content .append (event .delta .thinking )
74+ elif delta_type == "text_delta" :
75+ content = event .delta .text
76+ log_llm_stream (content )
77+ collected_content .append (content )
6678 elif event_type == "message_delta" :
6779 usage .output_tokens = event .usage .output_tokens # update final output_tokens
6880
6981 log_llm_stream ("\n " )
7082 self ._update_costs (usage )
7183 full_content = "" .join (collected_content )
84+ if collected_reasoning_content :
85+ self .reasoning_content = "" .join (collected_reasoning_content )
7286 return full_content
0 commit comments