11import json
2- from typing import Literal , Tuple
2+ from typing import Literal , Tuple , Union
33
44from metagpt .provider .bedrock .base_provider import BaseBedrockProvider
55from metagpt .provider .bedrock .utils import (
@@ -32,6 +32,10 @@ def _split_system_user_messages(self, messages: list[dict]) -> Tuple[str, list[d
3232 return self .messages_to_prompt (system_messages ), user_messages
3333
3434 def get_request_body (self , messages : list [dict ], generate_kwargs , * args , ** kwargs ) -> str :
35+ if self .reasoning :
36+ generate_kwargs ["temperature" ] = 1 # should be 1
37+ generate_kwargs ["thinking" ] = {"type" : "enabled" , "budget_tokens" : self .reasoning_tokens }
38+
3539 system_message , user_messages = self ._split_system_user_messages (messages )
3640 body = json .dumps (
3741 {
@@ -43,17 +47,26 @@ def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwarg
4347 )
4448 return body
4549
46- def _get_completion_from_dict (self , rsp_dict : dict ) -> str :
50+ def _get_completion_from_dict (self , rsp_dict : dict ) -> dict [str , Tuple [str , str ]]:
51+ if self .reasoning :
52+ return {"reasoning_content" : rsp_dict ["content" ][0 ]["thinking" ], "content" : rsp_dict ["content" ][1 ]["text" ]}
4753 return rsp_dict ["content" ][0 ]["text" ]
4854
49- def get_choice_text_from_stream (self , event ) -> str :
55+ def get_choice_text_from_stream (self , event ) -> Union [ bool , str ] :
5056 # https://docs.anthropic.com/claude/reference/messages-streaming
5157 rsp_dict = json .loads (event ["chunk" ]["bytes" ])
5258 if rsp_dict ["type" ] == "content_block_delta" :
53- completions = rsp_dict ["delta" ]["text" ]
54- return completions
59+ reasoning = False
60+ if rsp_dict ["delta" ]["type" ] == "text_delta" :
61+ completions = rsp_dict ["delta" ]["text" ]
62+ elif rsp_dict ["delta" ]["type" ] == "thinking_delta" :
63+ completions = rsp_dict ["delta" ]["thinking" ]
64+ reasoning = True
65+ elif rsp_dict ["delta" ]["type" ] == "signature_delta" :
66+ completions = ""
67+ return reasoning , completions
5568 else :
56- return ""
69+ return False , ""
5770
5871
5972class CohereProvider (BaseBedrockProvider ):
@@ -87,10 +100,10 @@ def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwarg
87100 body = json .dumps ({"prompt" : prompt , "stream" : kwargs .get ("stream" , False ), ** generate_kwargs })
88101 return body
89102
90- def get_choice_text_from_stream (self , event ) -> str :
103+ def get_choice_text_from_stream (self , event ) -> Union [ bool , str ] :
91104 rsp_dict = json .loads (event ["chunk" ]["bytes" ])
92105 completions = rsp_dict .get ("text" , "" )
93- return completions
106+ return False , completions
94107
95108
96109class MetaProvider (BaseBedrockProvider ):
@@ -133,10 +146,10 @@ def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwarg
133146 )
134147 return body
135148
136- def get_choice_text_from_stream (self , event ) -> str :
149+ def get_choice_text_from_stream (self , event ) -> Union [ bool , str ] :
137150 rsp_dict = json .loads (event ["chunk" ]["bytes" ])
138151 completions = rsp_dict .get ("choices" , [{}])[0 ].get ("delta" , {}).get ("content" , "" )
139- return completions
152+ return False , completions
140153
141154 def _get_completion_from_dict (self , rsp_dict : dict ) -> str :
142155 if self .model_type == "j2" :
@@ -159,10 +172,10 @@ def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwarg
159172 def _get_completion_from_dict (self , rsp_dict : dict ) -> str :
160173 return rsp_dict ["results" ][0 ]["outputText" ]
161174
162- def get_choice_text_from_stream (self , event ) -> str :
175+ def get_choice_text_from_stream (self , event ) -> Union [ bool , str ] :
163176 rsp_dict = json .loads (event ["chunk" ]["bytes" ])
164177 completions = rsp_dict ["outputText" ]
165- return completions
178+ return False , completions
166179
167180
168181PROVIDERS = {
@@ -175,8 +188,14 @@ def get_choice_text_from_stream(self, event) -> str:
175188}
176189
177190
178- def get_provider (model_id : str ):
179- provider , model_name = model_id .split ("." )[0 :2 ] # meta、mistral……
191+ def get_provider (model_id : str , reasoning : bool = False , reasoning_tokens : int = 4000 ):
192+ arr = model_id .split ("." )
193+ if len (arr ) == 2 :
194+ provider , model_name = arr # meta、mistral……
195+ elif len (arr ) == 3 :
196+ # some model_ids may contain country like us.xx.xxx
197+ _ , provider , model_name = arr
198+
180199 if provider not in PROVIDERS :
181200 raise KeyError (f"{ provider } is not supported!" )
182201 if provider == "meta" :
@@ -188,4 +207,4 @@ def get_provider(model_id: str):
188207 elif provider == "cohere" :
189208 # distinguish between R/R+ and older models
190209 return PROVIDERS [provider ](model_name )
191- return PROVIDERS [provider ]()
210+ return PROVIDERS [provider ](reasoning = reasoning , reasoning_tokens = reasoning_tokens )
0 commit comments