 import json
-from typing import Literal, Tuple
+from typing import Literal, Tuple, Union

 from metagpt.provider.bedrock.base_provider import BaseBedrockProvider
 from metagpt.provider.bedrock.utils import (
@@ -20,6 +20,8 @@ def _get_completion_from_dict(self, rsp_dict: dict) -> str:

 class AnthropicProvider(BaseBedrockProvider):
     # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
+    # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-37.html
+    # https://docs.aws.amazon.com/code-library/latest/ug/python_3_bedrock-runtime_code_examples.html#anthropic_claude

     def _split_system_user_messages(self, messages: list[dict]) -> Tuple[str, list[dict]]:
         system_messages = []
@@ -32,6 +34,10 @@ def _split_system_user_messages(self, messages: list[dict]) -> Tuple[str, list[d
         return self.messages_to_prompt(system_messages), user_messages

     def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs) -> str:
+        if self.reasoning:
+            generate_kwargs["temperature"] = 1  # Anthropic requires temperature=1 when extended thinking is enabled
+            generate_kwargs["thinking"] = {"type": "enabled", "budget_tokens": self.reasoning_max_token}
+
         system_message, user_messages = self._split_system_user_messages(messages)
         body = json.dumps(
             {
@@ -43,17 +49,27 @@ def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwarg
         )
         return body

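For orientation, here is a minimal usage sketch of the method above, assuming boto3's bedrock-runtime client; the model id, region, messages, generate_kwargs values, and the constructor keywords (mirroring the get_provider() change further down) are illustrative placeholders, not part of this change.

import json

import boto3

# Sketch only: build a reasoning-enabled request body and send it to Bedrock.
provider = AnthropicProvider(reasoning=True, reasoning_max_token=2000)  # constructor kwargs assumed from this PR
body = provider.get_request_body(
    [{"role": "user", "content": "Explain quicksort briefly."}],  # placeholder messages
    {"max_tokens": 4096},  # placeholder generate_kwargs; max_tokens must exceed the thinking budget
)
client = boto3.client("bedrock-runtime", region_name="us-east-1")
response = client.invoke_model(
    modelId="us.anthropic.claude-3-7-sonnet-20250219-v1:0",  # placeholder model id
    body=body,
)
rsp_dict = json.loads(response["body"].read())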
-    def _get_completion_from_dict(self, rsp_dict: dict) -> str:
+    def _get_completion_from_dict(self, rsp_dict: dict) -> Union[str, dict[str, str]]:
+        if self.reasoning:
+            return {"reasoning_content": rsp_dict["content"][0]["thinking"], "content": rsp_dict["content"][1]["text"]}
         return rsp_dict["content"][0]["text"]

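With extended thinking enabled, the parsed response puts a thinking block before the text block, which is what the content[0]/content[1] indexing above relies on. An abridged illustration of the shape, following Anthropic's messages documentation (values are made up):

# Abridged illustration of rsp_dict when reasoning is enabled:
rsp_dict = {
    "content": [
        {"type": "thinking", "thinking": "Let me work through this step by step ...", "signature": "..."},
        {"type": "text", "text": "Here is the final answer ..."},
    ],
    # id, model, stop_reason, usage, ... omitted
}
# _get_completion_from_dict() would then return:
# {"reasoning_content": "Let me work through this step by step ...",
#  "content": "Here is the final answer ..."}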
-    def get_choice_text_from_stream(self, event) -> str:
+    def get_choice_text_from_stream(self, event) -> Tuple[bool, str]:
         # https://docs.anthropic.com/claude/reference/messages-streaming
         rsp_dict = json.loads(event["chunk"]["bytes"])
         if rsp_dict["type"] == "content_block_delta":
-            completions = rsp_dict["delta"]["text"]
-            return completions
+            reasoning = False
+            delta_type = rsp_dict["delta"]["type"]
+            if delta_type == "text_delta":
+                completions = rsp_dict["delta"]["text"]
+            elif delta_type == "thinking_delta":
+                completions = rsp_dict["delta"]["thinking"]
+                reasoning = True
+            elif delta_type == "signature_delta":  # signature chunks carry no user-visible text
+                completions = ""
+            return reasoning, completions
         else:
-            return ""
+            return False, ""
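Because every provider's get_choice_text_from_stream now returns a (reasoning, text) pair, a caller draining a streaming response might look roughly like this; the loop and variable names are illustrative and not part of this change.

# Sketch only: split streamed chunks into a thinking trace and the visible answer.
# `response` is assumed to come from bedrock-runtime's invoke_model_with_response_stream().
reasoning_parts, text_parts = [], []
for event in response["body"]:
    is_reasoning, chunk = provider.get_choice_text_from_stream(event)
    (reasoning_parts if is_reasoning else text_parts).append(chunk)
thinking_trace = "".join(reasoning_parts)
answer = "".join(text_parts)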


 class CohereProvider(BaseBedrockProvider):
@@ -87,10 +103,10 @@ def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwarg
         body = json.dumps({"prompt": prompt, "stream": kwargs.get("stream", False), **generate_kwargs})
         return body

-    def get_choice_text_from_stream(self, event) -> str:
+    def get_choice_text_from_stream(self, event) -> Tuple[bool, str]:
         rsp_dict = json.loads(event["chunk"]["bytes"])
         completions = rsp_dict.get("text", "")
-        return completions
+        return False, completions


 class MetaProvider(BaseBedrockProvider):
@@ -133,10 +149,10 @@ def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwarg
         )
         return body

-    def get_choice_text_from_stream(self, event) -> str:
+    def get_choice_text_from_stream(self, event) -> Tuple[bool, str]:
         rsp_dict = json.loads(event["chunk"]["bytes"])
         completions = rsp_dict.get("choices", [{}])[0].get("delta", {}).get("content", "")
-        return completions
+        return False, completions

     def _get_completion_from_dict(self, rsp_dict: dict) -> str:
         if self.model_type == "j2":
@@ -159,10 +175,10 @@ def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwarg
     def _get_completion_from_dict(self, rsp_dict: dict) -> str:
         return rsp_dict["results"][0]["outputText"]

-    def get_choice_text_from_stream(self, event) -> str:
+    def get_choice_text_from_stream(self, event) -> Tuple[bool, str]:
         rsp_dict = json.loads(event["chunk"]["bytes"])
         completions = rsp_dict["outputText"]
-        return completions
+        return False, completions


 PROVIDERS = {
@@ -175,8 +191,14 @@ def get_choice_text_from_stream(self, event) -> str:
 }


-def get_provider(model_id: str):
-    provider, model_name = model_id.split(".")[0:2]  # meta、mistral……
+def get_provider(model_id: str, reasoning: bool = False, reasoning_max_token: int = 4000):
+    arr = model_id.split(".")
+    if len(arr) == 2:
+        provider, model_name = arr  # meta、mistral……
+    elif len(arr) == 3:
+        # some model_ids carry a region prefix, e.g. us.xx.xxx
+        _, provider, model_name = arr
+
     if provider not in PROVIDERS:
         raise KeyError(f"{provider} is not supported!")
     if provider == "meta":
@@ -188,4 +210,4 @@ def get_provider(model_id: str):
     elif provider == "cohere":
         # distinguish between R/R+ and older models
         return PROVIDERS[provider](model_name)
-    return PROVIDERS[provider]()
+    return PROVIDERS[provider](reasoning=reasoning, reasoning_max_token=reasoning_max_token)
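A short usage sketch of the updated factory; both model ids below are placeholders that only illustrate the two accepted id shapes.

# Plain "<provider>.<model_name>" id:
claude = get_provider("anthropic.claude-3-5-sonnet-20240620-v1:0")

# Region-prefixed id (e.g. a "us." cross-region profile) with extended thinking enabled:
claude_thinking = get_provider(
    "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
    reasoning=True,
    reasoning_max_token=2000,
)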