Commit b9d5d41

feat: Support reasoning content (WIP)

1 parent 808fc7c
File tree: 1 file changed (+80, −9)

apps/setting/models_provider/impl/base_chat_open_ai.py

Lines changed: 80 additions & 9 deletions
```diff
@@ -1,12 +1,16 @@
 # coding=utf-8
+import warnings
+from typing import List, Dict, Optional, Any, Iterator, cast, Type
 
-from typing import List, Dict, Optional, Any, Iterator, cast
-
+import openai
+from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models import LanguageModelInput
-from langchain_core.messages import BaseMessage, get_buffer_string
+from langchain_core.messages import BaseMessage, get_buffer_string, BaseMessageChunk, AIMessageChunk
 from langchain_core.outputs import ChatGenerationChunk, ChatGeneration
 from langchain_core.runnables import RunnableConfig, ensure_config
+from langchain_core.utils.pydantic import is_basemodel_subclass
 from langchain_openai import ChatOpenAI
+from langchain_openai.chat_models.base import _convert_chunk_to_generation_chunk
 
 from common.config.tokenizer_manage_config import TokenizerManage
 
@@ -36,14 +40,81 @@ def get_num_tokens(self, text: str) -> int:
         return self.get_last_generation_info().get('output_tokens', 0)
 
     def _stream(
-            self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
+
+        """Set default stream_options."""
+        stream_usage = self._should_stream_usage(kwargs.get('stream_usage'), **kwargs)
+        # Note: stream_options is not a valid parameter for Azure OpenAI.
+        # To support users proxying Azure through ChatOpenAI, here we only specify
+        # stream_options if include_usage is set to True.
+        # See https://learn.microsoft.com/en-us/azure/ai-services/openai/whats-new
+        # for release notes.
+        if stream_usage:
+            kwargs["stream_options"] = {"include_usage": stream_usage}
+
         kwargs["stream"] = True
-        kwargs["stream_options"] = {"include_usage": True}
-        for chunk in super()._stream(*args, stream_usage=stream_usage, **kwargs):
-            if chunk.message.usage_metadata is not None:
-                self.usage_metadata = chunk.message.usage_metadata
-            yield chunk
+        payload = self._get_request_payload(messages, stop=stop, **kwargs)
+        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
+        base_generation_info = {}
+
+        if "response_format" in payload and is_basemodel_subclass(
+                payload["response_format"]
+        ):
+            # TODO: Add support for streaming with Pydantic response_format.
+            warnings.warn("Streaming with Pydantic response_format not yet supported.")
+            chat_result = self._generate(
+                messages, stop, run_manager=run_manager, **kwargs
+            )
+            msg = chat_result.generations[0].message
+            yield ChatGenerationChunk(
+                message=AIMessageChunk(
+                    **msg.dict(exclude={"type", "additional_kwargs"}),
+                    # preserve the "parsed" Pydantic object without converting to dict
+                    additional_kwargs=msg.additional_kwargs,
+                ),
+                generation_info=chat_result.generations[0].generation_info,
+            )
+            return
+        if self.include_response_headers:
+            raw_response = self.client.with_raw_response.create(**payload)
+            response = raw_response.parse()
+            base_generation_info = {"headers": dict(raw_response.headers)}
+        else:
+            response = self.client.create(**payload)
+        with response:
+            is_first_chunk = True
+            for chunk in response:
+                if not isinstance(chunk, dict):
+                    chunk = chunk.model_dump()
+
+                generation_chunk = _convert_chunk_to_generation_chunk(
+                    chunk,
+                    default_chunk_class,
+                    base_generation_info if is_first_chunk else {},
+                )
+                if generation_chunk is None:
+                    continue
+
+                # custom code: capture token usage reported by the streamed chunks
+                if generation_chunk.message.usage_metadata is not None:
+                    self.usage_metadata = generation_chunk.message.usage_metadata
+                # custom code: surface reasoning tokens (guarded lookup, since plain deltas omit the field and the usage-only final chunk has no choices)
+                if (chunk.get('choices') or [{}])[0].get('delta', {}).get('reasoning_content'):
+                    generation_chunk.message.additional_kwargs["reasoning_content"] = chunk['choices'][0]['delta']['reasoning_content']
+
+                default_chunk_class = generation_chunk.message.__class__
+                logprobs = (generation_chunk.generation_info or {}).get("logprobs")
+                if run_manager:
+                    run_manager.on_llm_new_token(
+                        generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
+                    )
+                is_first_chunk = False
+                yield generation_chunk
 
     def invoke(
         self,
```
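Two behaviors in the rewritten `_stream` are worth illustrating. First, `stream_options` is only sent when usage reporting is requested, because Azure deployments reject the parameter; when it is sent, OpenAI-compatible endpoints append one final chunk whose `choices` list is empty and whose `usage` field is populated, which is what feeds the `usage_metadata` capture above. A minimal sketch of that contract against the raw `openai` client (endpoint and model name are illustrative, not part of the commit):

```python
# Minimal sketch of the include_usage contract; endpoint and model are illustrative.
from openai import OpenAI

client = OpenAI()
stream = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Say hi"}],
    stream=True,
    stream_options={"include_usage": True},
)
for chunk in stream:
    if chunk.choices:  # ordinary delta chunks
        print(chunk.choices[0].delta.content or "", end="")
    if chunk.usage is not None:  # final chunk: choices == [], usage populated
        print(f"\n{chunk.usage.prompt_tokens} prompt / {chunk.usage.completion_tokens} completion tokens")
```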

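Second, `reasoning_content` is not part of the standard OpenAI schema: DeepSeek-style reasoning endpoints attach it to each delta, plain models omit it entirely, and the usage-only final chunk carries no choices at all. A sketch of the two payload shapes the guarded lookup in the loop has to tolerate (all field values made up):

```python
# Illustrative chunk payloads as dicts (after model_dump()); values are made up.
reasoning_delta = {
    "choices": [{"delta": {"content": None, "reasoning_content": "First, consider..."}}],
    "usage": None,
}
final_usage_chunk = {
    "choices": [],  # empty when stream_options include_usage is set
    "usage": {"prompt_tokens": 12, "completion_tokens": 84, "total_tokens": 96},
}

def extract_reasoning(chunk: dict):
    # Same guarded lookup as the diff: tolerate missing choices, delta, or field.
    return (chunk.get("choices") or [{}])[0].get("delta", {}).get("reasoning_content")

assert extract_reasoning(reasoning_delta) == "First, consider..."
assert extract_reasoning(final_usage_chunk) is None
```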
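Downstream, the reasoning stream can then be read from `additional_kwargs` on each yielded chunk. A hypothetical consumer sketch, assuming the patched class is named `BaseChatOpenAI` (the class name is not visible in the hunk) and a DeepSeek-style endpoint; model, base URL, and key are placeholders:

```python
# Hypothetical consumer; class name, model, URL, and key are placeholders.
from langchain_core.messages import HumanMessage

from apps.setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI

llm = BaseChatOpenAI(
    model="deepseek-reasoner",
    openai_api_base="https://api.deepseek.com/v1",
    openai_api_key="<api-key>",
)
reasoning, answer = [], []
for chunk in llm.stream([HumanMessage(content="Why is the sky blue?")]):
    part = chunk.additional_kwargs.get("reasoning_content")
    if part:
        reasoning.append(part)  # thinking tokens surfaced by this commit
    if chunk.content:
        answer.append(chunk.content)  # ordinary answer tokens
print("reasoning:", "".join(reasoning))
print("answer:", "".join(answer))
```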