Commit 55726e0

Merge pull request #1726 from better629/reasoning
add LLM Reasoning models
2 parents 44eccf0 + 08587f3

15 files changed: +233 additions, −68 deletions

examples/hello_world.py

Lines changed: 3 additions & 1 deletion
@@ -13,7 +13,9 @@

 async def ask_and_print(question: str, llm: LLM, system_prompt) -> str:
     logger.info(f"Q: {question}")
-    rsp = await llm.aask(question, system_msgs=[system_prompt])
+    rsp = await llm.aask(question, system_msgs=[system_prompt], stream=True)
+    if llm.reasoning_content:
+        logger.info(f"A reasoning: {llm.reasoning_content}")
     logger.info(f"A: {rsp}")
     return rsp

metagpt/configs/llm_config.py

Lines changed: 5 additions & 0 deletions
@@ -40,6 +40,7 @@ class LLMType(Enum):
     DEEPSEEK = "deepseek"
     SILICONFLOW = "siliconflow"
     OPENROUTER = "openrouter"
+    OPENROUTER_REASONING = "openrouter_reasoning"
     BEDROCK = "bedrock"
     ARK = "ark"  # https://www.volcengine.com/docs/82379/1263482#python-sdk

@@ -107,6 +108,10 @@ class LLMConfig(YamlModel):
     # For Messages Control
     use_system_prompt: bool = True

+    # reasoning / thinking switch
+    reasoning: bool = False
+    reasoning_max_token: int = 4000  # reasoning budget tokens to generate, usually smaller than max_token
+
     @field_validator("api_key")
     @classmethod
     def check_llm_key(cls, v):
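
Both new fields sit on LLMConfig, so turning on reasoning is a configuration change rather than a code change. A minimal sketch, assuming the other fields shown (api_type, api_key, model) keep their existing meaning; the model id and key are illustrative only:

from metagpt.configs.llm_config import LLMConfig, LLMType

config = LLMConfig(
    api_type=LLMType.ANTHROPIC,
    api_key="sk-...",                     # placeholder
    model="claude-3-7-sonnet-20250219",   # assumed model id, not taken from this diff
    reasoning=True,                       # switch on thinking / reasoning mode
    reasoning_max_token=2000,             # thinking budget, kept below max_token
)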

metagpt/provider/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -19,6 +19,7 @@
 from metagpt.provider.anthropic_api import AnthropicLLM
 from metagpt.provider.bedrock_api import BedrockLLM
 from metagpt.provider.ark_api import ArkLLM
+from metagpt.provider.openrouter_reasoning import OpenrouterReasoningLLM

 __all__ = [
     "GeminiLLM",

@@ -34,4 +35,5 @@
     "AnthropicLLM",
     "BedrockLLM",
     "ArkLLM",
+    "OpenrouterReasoningLLM",
 ]

metagpt/provider/anthropic_api.py

Lines changed: 18 additions & 4 deletions
@@ -33,14 +33,21 @@ def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
         if messages[0]["role"] == "system":
             kwargs["messages"] = messages[1:]
             kwargs["system"] = messages[0]["content"]  # set system prompt here
+        if self.config.reasoning:
+            kwargs["thinking"] = {"type": "enabled", "budget_tokens": self.config.reasoning_max_token}
         return kwargs

     def _update_costs(self, usage: Usage, model: str = None, local_calc_usage: bool = True):
         usage = {"prompt_tokens": usage.input_tokens, "completion_tokens": usage.output_tokens}
         super()._update_costs(usage, model)

     def get_choice_text(self, resp: Message) -> str:
-        return resp.content[0].text
+        if len(resp.content) > 1:
+            self.reasoning_content = resp.content[0].thinking
+            text = resp.content[1].text
+        else:
+            text = resp.content[0].text
+        return text

     async def _achat_completion(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> Message:
         resp: Message = await self.aclient.messages.create(**self._const_kwargs(messages))

@@ -53,20 +60,27 @@ async def acompletion(self, messages: list[dict], timeout: int = USE_CONFIG_TIME
     async def _achat_completion_stream(self, messages: list[dict], timeout: int = USE_CONFIG_TIMEOUT) -> str:
         stream = await self.aclient.messages.create(**self._const_kwargs(messages, stream=True))
         collected_content = []
+        collected_reasoning_content = []
         usage = Usage(input_tokens=0, output_tokens=0)
         async for event in stream:
             event_type = event.type
             if event_type == "message_start":
                 usage.input_tokens = event.message.usage.input_tokens
                 usage.output_tokens = event.message.usage.output_tokens
             elif event_type == "content_block_delta":
-                content = event.delta.text
-                log_llm_stream(content)
-                collected_content.append(content)
+                delta_type = event.delta.type
+                if delta_type == "thinking_delta":
+                    collected_reasoning_content.append(event.delta.thinking)
+                elif delta_type == "text_delta":
+                    content = event.delta.text
+                    log_llm_stream(content)
+                    collected_content.append(content)
             elif event_type == "message_delta":
                 usage.output_tokens = event.usage.output_tokens  # update final output_tokens

         log_llm_stream("\n")
         self._update_costs(usage)
         full_content = "".join(collected_content)
+        if collected_reasoning_content:
+            self.reasoning_content = "".join(collected_reasoning_content)
         return full_content
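
With extended thinking enabled, a non-streaming Anthropic response carries a thinking block ahead of the text block, which is why get_choice_text above branches on len(resp.content). A rough sketch of the shape being handled, using mock objects rather than the real SDK types:

from types import SimpleNamespace

# Mock of the two content blocks returned when thinking is enabled (assumed shape).
resp = SimpleNamespace(
    content=[
        SimpleNamespace(type="thinking", thinking="Compare the two options step by step..."),
        SimpleNamespace(type="text", text="Option B is cheaper overall."),
    ]
)

# Mirrors the branch in get_choice_text: block 0 is reasoning, block 1 is the answer.
if len(resp.content) > 1:
    reasoning_content = resp.content[0].thinking
    text = resp.content[1].text
else:
    text = resp.content[0].text
print(text)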

metagpt/provider/base_llm.py

Lines changed: 14 additions & 1 deletion
@@ -45,6 +45,16 @@ class BaseLLM(ABC):
     model: Optional[str] = None  # deprecated
     pricing_plan: Optional[str] = None

+    _reasoning_content: Optional[str] = None  # content from reasoning mode
+
+    @property
+    def reasoning_content(self):
+        return self._reasoning_content
+
+    @reasoning_content.setter
+    def reasoning_content(self, value: str):
+        self._reasoning_content = value
+
     @abstractmethod
     def __init__(self, config: LLMConfig):
         pass

@@ -216,7 +226,10 @@ async def acompletion_text(

     def get_choice_text(self, rsp: dict) -> str:
         """Required to provide the first text of choice"""
-        return rsp.get("choices")[0]["message"]["content"]
+        message = rsp.get("choices")[0]["message"]
+        if "reasoning_content" in message:
+            self.reasoning_content = message["reasoning_content"]
+        return message["content"]

     def get_choice_delta_text(self, rsp: dict) -> str:
         """Required to provide the first text of stream choice"""

metagpt/provider/bedrock/base_provider.py

Lines changed: 8 additions & 3 deletions
@@ -1,11 +1,16 @@
 import json
 from abc import ABC, abstractmethod
+from typing import Union


 class BaseBedrockProvider(ABC):
     # to handle different generation kwargs
     max_tokens_field_name = "max_tokens"

+    def __init__(self, reasoning: bool = False, reasoning_max_token: int = 4000):
+        self.reasoning = reasoning
+        self.reasoning_max_token = reasoning_max_token
+
     @abstractmethod
     def _get_completion_from_dict(self, rsp_dict: dict) -> str:
         ...

@@ -14,14 +19,14 @@ def get_request_body(self, messages: list[dict], const_kwargs, *args, **kwargs)
         body = json.dumps({"prompt": self.messages_to_prompt(messages), **const_kwargs})
         return body

-    def get_choice_text(self, response_body: dict) -> str:
+    def get_choice_text(self, response_body: dict) -> Union[str, dict[str, str]]:
         completions = self._get_completion_from_dict(response_body)
         return completions

-    def get_choice_text_from_stream(self, event) -> str:
+    def get_choice_text_from_stream(self, event) -> Union[bool, str]:
         rsp_dict = json.loads(event["chunk"]["bytes"])
         completions = self._get_completion_from_dict(rsp_dict)
-        return completions
+        return False, completions

     def messages_to_prompt(self, messages: list[dict]) -> str:
         """[{"role": "user", "content": msg}] to user: <msg> etc."""

metagpt/provider/bedrock/bedrock_provider.py

Lines changed: 37 additions & 15 deletions
@@ -1,5 +1,5 @@
 import json
-from typing import Literal, Tuple
+from typing import Literal, Tuple, Union

 from metagpt.provider.bedrock.base_provider import BaseBedrockProvider
 from metagpt.provider.bedrock.utils import (

@@ -20,6 +20,8 @@ def _get_completion_from_dict(self, rsp_dict: dict) -> str:

 class AnthropicProvider(BaseBedrockProvider):
     # See https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-messages.html
+    # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-anthropic-claude-37.html
+    # https://docs.aws.amazon.com/code-library/latest/ug/python_3_bedrock-runtime_code_examples.html#anthropic_claude

     def _split_system_user_messages(self, messages: list[dict]) -> Tuple[str, list[dict]]:
         system_messages = []

@@ -32,6 +34,10 @@ def _split_system_user_messages(self, messages: list[dict]) -> Tuple[str, list[d
         return self.messages_to_prompt(system_messages), user_messages

     def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwargs) -> str:
+        if self.reasoning:
+            generate_kwargs["temperature"] = 1  # should be 1
+            generate_kwargs["thinking"] = {"type": "enabled", "budget_tokens": self.reasoning_max_token}
+
         system_message, user_messages = self._split_system_user_messages(messages)
         body = json.dumps(
             {

@@ -43,17 +49,27 @@ def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwarg
         )
         return body

-    def _get_completion_from_dict(self, rsp_dict: dict) -> str:
+    def _get_completion_from_dict(self, rsp_dict: dict) -> dict[str, Tuple[str, str]]:
+        if self.reasoning:
+            return {"reasoning_content": rsp_dict["content"][0]["thinking"], "content": rsp_dict["content"][1]["text"]}
         return rsp_dict["content"][0]["text"]

-    def get_choice_text_from_stream(self, event) -> str:
+    def get_choice_text_from_stream(self, event) -> Union[bool, str]:
         # https://docs.anthropic.com/claude/reference/messages-streaming
         rsp_dict = json.loads(event["chunk"]["bytes"])
         if rsp_dict["type"] == "content_block_delta":
-            completions = rsp_dict["delta"]["text"]
-            return completions
+            reasoning = False
+            delta_type = rsp_dict["delta"]["type"]
+            if delta_type == "text_delta":
+                completions = rsp_dict["delta"]["text"]
+            elif delta_type == "thinking_delta":
+                completions = rsp_dict["delta"]["thinking"]
+                reasoning = True
+            elif delta_type == "signature_delta":
+                completions = ""
+            return reasoning, completions
         else:
-            return ""
+            return False, ""


 class CohereProvider(BaseBedrockProvider):

@@ -87,10 +103,10 @@ def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwarg
         body = json.dumps({"prompt": prompt, "stream": kwargs.get("stream", False), **generate_kwargs})
         return body

-    def get_choice_text_from_stream(self, event) -> str:
+    def get_choice_text_from_stream(self, event) -> Union[bool, str]:
         rsp_dict = json.loads(event["chunk"]["bytes"])
         completions = rsp_dict.get("text", "")
-        return completions
+        return False, completions


 class MetaProvider(BaseBedrockProvider):

@@ -133,10 +149,10 @@ def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwarg
         )
         return body

-    def get_choice_text_from_stream(self, event) -> str:
+    def get_choice_text_from_stream(self, event) -> Union[bool, str]:
         rsp_dict = json.loads(event["chunk"]["bytes"])
         completions = rsp_dict.get("choices", [{}])[0].get("delta", {}).get("content", "")
-        return completions
+        return False, completions

     def _get_completion_from_dict(self, rsp_dict: dict) -> str:
         if self.model_type == "j2":

@@ -159,10 +175,10 @@ def get_request_body(self, messages: list[dict], generate_kwargs, *args, **kwarg
     def _get_completion_from_dict(self, rsp_dict: dict) -> str:
         return rsp_dict["results"][0]["outputText"]

-    def get_choice_text_from_stream(self, event) -> str:
+    def get_choice_text_from_stream(self, event) -> Union[bool, str]:
         rsp_dict = json.loads(event["chunk"]["bytes"])
         completions = rsp_dict["outputText"]
-        return completions
+        return False, completions


 PROVIDERS = {

@@ -175,8 +191,14 @@ def get_choice_text_from_stream(self, event) -> str:
 }


-def get_provider(model_id: str):
-    provider, model_name = model_id.split(".")[0:2]  # meta、mistral……
+def get_provider(model_id: str, reasoning: bool = False, reasoning_max_token: int = 4000):
+    arr = model_id.split(".")
+    if len(arr) == 2:
+        provider, model_name = arr  # meta、mistral……
+    elif len(arr) == 3:
+        # some model_ids may contain country like us.xx.xxx
+        _, provider, model_name = arr
+
     if provider not in PROVIDERS:
         raise KeyError(f"{provider} is not supported!")
     if provider == "meta":

@@ -188,4 +210,4 @@ def get_provider(model_id: str):
     elif provider == "cohere":
         # distinguish between R/R+ and older models
         return PROVIDERS[provider](model_name)
-    return PROVIDERS[provider]()
+    return PROVIDERS[provider](reasoning=reasoning, reasoning_max_token=reasoning_max_token)
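
get_provider now also accepts cross-region inference model ids, which carry a region prefix in front of the provider segment. A small usage sketch with ids that appear in this commit's utils.py table:

from metagpt.provider.bedrock.bedrock_provider import get_provider

# Two-part id: provider.model
p1 = get_provider("anthropic.claude-3-5-sonnet-20240620-v1:0")

# Three-part id with a region prefix (us.provider.model), passing the new reasoning options through
p2 = get_provider(
    "us.anthropic.claude-3-7-sonnet-20250219-v1:0",
    reasoning=True,
    reasoning_max_token=4000,
)
print(type(p1).__name__, type(p2).__name__)  # both resolve to AnthropicProvider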

metagpt/provider/bedrock/utils.py

Lines changed: 3 additions & 14 deletions
@@ -48,6 +48,9 @@
     "anthropic.claude-3-opus-20240229-v1:0": 4096,
     # Claude 3.5 Sonnet
     "anthropic.claude-3-5-sonnet-20240620-v1:0": 8192,
+    # Claude 3.7 Sonnet
+    "us.anthropic.claude-3-7-sonnet-20250219-v1:0": 131072,
+    "anthropic.claude-3-7-sonnet-20250219-v1:0": 131072,
     # Command Text
     # https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-cohere-command.html
     "cohere.command-text-v14": 4096,

@@ -135,20 +138,6 @@ def messages_to_prompt_llama3(messages: list[dict]) -> str:
     return prompt


-def messages_to_prompt_claude2(messages: list[dict]) -> str:
-    GENERAL_TEMPLATE = "\n\n{role}: {content}"
-    prompt = ""
-    for message in messages:
-        role = message.get("role", "")
-        content = message.get("content", "")
-        prompt += GENERAL_TEMPLATE.format(role=role, content=content)
-
-    if role != "assistant":
-        prompt += "\n\nAssistant:"
-
-    return prompt
-
-
 def get_max_tokens(model_id: str) -> int:
     try:
         max_tokens = (NOT_SUPPORT_STREAM_MODELS | SUPPORT_STREAM_MODELS)[model_id]

metagpt/provider/bedrock_api.py

Lines changed: 17 additions & 5 deletions
@@ -23,7 +23,9 @@ class BedrockLLM(BaseLLM):
     def __init__(self, config: LLMConfig):
         self.config = config
         self.__client = self.__init_client("bedrock-runtime")
-        self.__provider = get_provider(self.config.model)
+        self.__provider = get_provider(
+            self.config.model, reasoning=self.config.reasoning, reasoning_max_token=self.config.reasoning_max_token
+        )
         self.cost_manager = CostManager(token_costs=BEDROCK_TOKEN_COSTS)
         if self.config.model in NOT_SUPPORT_STREAM_MODELS:
             logger.warning(f"model {self.config.model} doesn't support streaming output!")

@@ -102,7 +104,11 @@ def _const_kwargs(self) -> dict:
     # However,aioboto3 doesn't support invoke model

     def get_choice_text(self, rsp: dict) -> str:
-        return self.__provider.get_choice_text(rsp)
+        rsp = self.__provider.get_choice_text(rsp)
+        if isinstance(rsp, dict):
+            self.reasoning_content = rsp.get("reasoning_content")
+            rsp = rsp.get("content")
+        return rsp

     async def acompletion(self, messages: list[dict]) -> dict:
         request_body = self.__provider.get_request_body(messages, self._const_kwargs)

@@ -133,10 +139,16 @@ def _get_response_body(self, response) -> dict:
     async def _get_stream_response_body(self, stream_response) -> List[str]:
        def collect_content() -> str:
            collected_content = []
+           collected_reasoning_content = []
            for event in stream_response["body"]:
-               chunk_text = self.__provider.get_choice_text_from_stream(event)
-               collected_content.append(chunk_text)
-               log_llm_stream(chunk_text)
+               reasoning, chunk_text = self.__provider.get_choice_text_from_stream(event)
+               if reasoning:
+                   collected_reasoning_content.append(chunk_text)
+               else:
+                   collected_content.append(chunk_text)
+                   log_llm_stream(chunk_text)
+           if collected_reasoning_content:
+               self.reasoning_content = "".join(collected_reasoning_content)
            return collected_content

        loop = asyncio.get_running_loop()

metagpt/provider/general_api_base.py

Lines changed: 8 additions & 7 deletions
@@ -150,6 +150,14 @@ def response_ms(self) -> Optional[int]:
         h = self._headers.get("Openai-Processing-Ms")
         return None if h is None else round(float(h))

+    def decode_asjson(self) -> Optional[dict]:
+        bstr = self.data.strip()
+        if bstr.startswith(b"{") and bstr.endswith(b"}"):
+            bstr = bstr.decode("utf-8")
+        else:
+            bstr = parse_stream_helper(bstr)
+        return json.loads(bstr) if bstr else None
+

 def _build_api_url(url, query):
     scheme, netloc, path, base_query, fragment = urlsplit(url)

@@ -547,13 +555,6 @@ async def arequest_raw(
         }
         try:
             result = await session.request(**request_kwargs)
-            # log_info(
-            #     "LLM API response",
-            #     path=abs_url,
-            #     response_code=result.status,
-            #     processing_ms=result.headers.get("LLM-Processing-Ms"),
-            #     request_id=result.headers.get("X-Request-Id"),
-            # )
             return result
         except (aiohttp.ServerTimeoutError, asyncio.TimeoutError) as e:
             raise openai.APITimeoutError("Request timed out") from e
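
decode_asjson normalizes two response shapes: a raw JSON body (bytes starting and ending with braces) is decoded directly, while anything else is treated as a stream chunk and passed through parse_stream_helper first. A rough standalone sketch of that branching; the stand-in helper here strips an SSE "data: " prefix, which is an assumption about parse_stream_helper's behavior:

import json
from typing import Optional


def _parse_stream_line(bstr: bytes) -> Optional[bytes]:
    # Simplified stand-in for parse_stream_helper (assumed behavior).
    if bstr.startswith(b"data: "):
        payload = bstr[len(b"data: "):]
        return None if payload == b"[DONE]" else payload
    return None


def decode_asjson(data: bytes) -> Optional[dict]:
    bstr = data.strip()
    if bstr.startswith(b"{") and bstr.endswith(b"}"):
        text = bstr.decode("utf-8")
    else:
        text = _parse_stream_line(bstr)
    return json.loads(text) if text else None


print(decode_asjson(b'{"ok": true}'))            # {'ok': True}
print(decode_asjson(b'data: {"delta": "hi"}'))   # {'delta': 'hi'}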
