@@ -1,15 +1,16 @@
 # coding=utf-8
-import warnings
-from typing import List, Dict, Optional, Any, Iterator, cast, Type, Union
+from typing import Dict, Optional, Any, Iterator, cast, Union, Sequence, Callable, Mapping
 
-import openai
-from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models import LanguageModelInput
-from langchain_core.messages import BaseMessage, get_buffer_string, BaseMessageChunk, AIMessageChunk
-from langchain_core.outputs import ChatGenerationChunk, ChatGeneration
+from langchain_core.messages import BaseMessage, get_buffer_string, BaseMessageChunk, HumanMessageChunk, AIMessageChunk, \
+    SystemMessageChunk, FunctionMessageChunk, ChatMessageChunk
+from langchain_core.messages.ai import UsageMetadata
+from langchain_core.messages.tool import tool_call_chunk, ToolMessageChunk
+from langchain_core.outputs import ChatGenerationChunk
 from langchain_core.runnables import RunnableConfig, ensure_config
-from langchain_core.utils.pydantic import is_basemodel_subclass
+from langchain_core.tools import BaseTool
 from langchain_openai import ChatOpenAI
+from langchain_openai.chat_models.base import _create_usage_metadata
 
 from common.config.tokenizer_manage_config import TokenizerManage
 
@@ -19,14 +20,78 @@ def custom_get_token_ids(text: str):
     return tokenizer.encode(text)
 
 
+def _convert_delta_to_message_chunk(
+    _dict: Mapping[str, Any], default_class: type[BaseMessageChunk]
+) -> BaseMessageChunk:
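+    # Translate one raw streaming delta dict into the matching LangChain
+    # message chunk, carrying any provider-specific reasoning_content along
+    # in additional_kwargs.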
+    id_ = _dict.get("id")
+    reasoning_content = cast(str, _dict.get("reasoning_content") or "")
+    role = cast(str, _dict.get("role"))
+    content = cast(str, _dict.get("content") or "")
+    additional_kwargs: dict = {'reasoning_content': reasoning_content}
+    if _dict.get("function_call"):
+        function_call = dict(_dict["function_call"])
+        if "name" in function_call and function_call["name"] is None:
+            function_call["name"] = ""
+        additional_kwargs["function_call"] = function_call
+    tool_call_chunks = []
+    if raw_tool_calls := _dict.get("tool_calls"):
+        additional_kwargs["tool_calls"] = raw_tool_calls
+        try:
+            tool_call_chunks = [
+                tool_call_chunk(
+                    name=rtc["function"].get("name"),
+                    args=rtc["function"].get("arguments"),
+                    id=rtc.get("id"),
+                    index=rtc["index"],
+                )
+                for rtc in raw_tool_calls
+            ]
+        except KeyError:
+            pass
+
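+    # Dispatch on the delta's role, falling back to the chunk class inferred
+    # from earlier chunks in the same stream.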
+    if role == "user" or default_class == HumanMessageChunk:
+        return HumanMessageChunk(content=content, id=id_)
+    elif role == "assistant" or default_class == AIMessageChunk:
+        return AIMessageChunk(
+            content=content,
+            additional_kwargs=additional_kwargs,
+            id=id_,
+            tool_call_chunks=tool_call_chunks,  # type: ignore[arg-type]
+        )
+    elif role in ("system", "developer") or default_class == SystemMessageChunk:
+        if role == "developer":
+            additional_kwargs = {"__openai_role__": "developer"}
+        else:
+            additional_kwargs = {}
+        return SystemMessageChunk(
+            content=content, id=id_, additional_kwargs=additional_kwargs
+        )
+    elif role == "function" or default_class == FunctionMessageChunk:
+        return FunctionMessageChunk(content=content, name=_dict["name"], id=id_)
+    elif role == "tool" or default_class == ToolMessageChunk:
+        return ToolMessageChunk(
+            content=content, tool_call_id=_dict["tool_call_id"], id=id_
+        )
+    elif role or default_class == ChatMessageChunk:
+        return ChatMessageChunk(content=content, role=role, id=id_)
+    else:
+        return default_class(content=content, id=id_)  # type: ignore
+
+
 class BaseChatOpenAI(ChatOpenAI):
     usage_metadata: dict = {}
     custom_get_token_ids = custom_get_token_ids
 
     def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
         return self.usage_metadata
 
-    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+    def get_num_tokens_from_messages(
+            self,
+            messages: list[BaseMessage],
+            tools: Optional[
+                Sequence[Union[dict[str, Any], type, Callable, BaseTool]]
+            ] = None,
+    ) -> int:
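+        # Prefer provider-reported usage cached from the last call; otherwise
+        # fall back to counting with the local tokenizer.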
         if self.usage_metadata is None or self.usage_metadata == {}:
             try:
                 return super().get_num_tokens_from_messages(messages)
@@ -44,114 +109,77 @@ def get_num_tokens(self, text: str) -> int:
             return len(tokenizer.encode(text))
         return self.get_last_generation_info().get('output_tokens', 0)
 
-    def _stream(
+    def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]:
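+        # Force usage reporting on so the final streamed chunk carries token
+        # counts, which we cache on the instance.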
+        kwargs['stream_usage'] = True
+        for chunk in super()._stream(*args, **kwargs):
+            if chunk.message.usage_metadata is not None:
+                self.usage_metadata = chunk.message.usage_metadata
+            yield chunk
+
+    def _convert_chunk_to_generation_chunk(
             self,
-            messages: List[BaseMessage],
-            stop: Optional[List[str]] = None,
-            run_manager: Optional[CallbackManagerForLLMRun] = None,
-            **kwargs: Any,
-    ) -> Iterator[ChatGenerationChunk]:
-        kwargs["stream"] = True
-        kwargs["stream_options"] = {"include_usage": True}
-        """Set default stream_options."""
-        stream_usage = self._should_stream_usage(kwargs.get('stream_usage'), **kwargs)
-        # Note: stream_options is not a valid parameter for Azure OpenAI.
-        # To support users proxying Azure through ChatOpenAI, here we only specify
-        # stream_options if include_usage is set to True.
-        # See https://learn.microsoft.com/en-us/azure/ai-services/openai/whats-new
-        # for release notes.
-        if stream_usage:
-            kwargs["stream_options"] = {"include_usage": stream_usage}
-
-        payload = self._get_request_payload(messages, stop=stop, **kwargs)
-        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
-        base_generation_info = {}
-
-        if "response_format" in payload and is_basemodel_subclass(
-                payload["response_format"]
-        ):
-            # TODO: Add support for streaming with Pydantic response_format.
-            warnings.warn("Streaming with Pydantic response_format not yet supported.")
-            chat_result = self._generate(
-                messages, stop, run_manager=run_manager, **kwargs
-            )
-            msg = chat_result.generations[0].message
-            yield ChatGenerationChunk(
-                message=AIMessageChunk(
-                    **msg.dict(exclude={"type", "additional_kwargs"}),
-                    # preserve the "parsed" Pydantic object without converting to dict
-                    additional_kwargs=msg.additional_kwargs,
-                ),
-                generation_info=chat_result.generations[0].generation_info,
+            chunk: dict,
+            default_chunk_class: type,
+            base_generation_info: Optional[dict],
+    ) -> Optional[ChatGenerationChunk]:
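+        # Largely mirrors the upstream langchain-openai chunk conversion, but
+        # routes deltas through the local _convert_delta_to_message_chunk so
+        # reasoning_content survives into additional_kwargs.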
+        if chunk.get("type") == "content.delta":  # from beta.chat.completions.stream
+            return None
+        token_usage = chunk.get("usage")
+        choices = (
+            chunk.get("choices", [])
+            # from beta.chat.completions.stream
+            or chunk.get("chunk", {}).get("choices", [])
+        )
+
+        usage_metadata: Optional[UsageMetadata] = (
+            _create_usage_metadata(token_usage) if token_usage else None
+        )
+        if len(choices) == 0:
+            # logprobs is implicitly None
+            generation_chunk = ChatGenerationChunk(
+                message=default_chunk_class(content="", usage_metadata=usage_metadata)
             )
-            return
-        if self.include_response_headers:
-            raw_response = self.client.with_raw_response.create(**payload)
-            response = raw_response.parse()
-            base_generation_info = {"headers": dict(raw_response.headers)}
-        else:
-            response = self.client.create(**payload)
-        with response:
-            is_first_chunk = True
-            for chunk in response:
-                if not isinstance(chunk, dict):
-                    chunk = chunk.model_dump()
-
-                generation_chunk = super()._convert_chunk_to_generation_chunk(
-                    chunk,
-                    default_chunk_class,
-                    base_generation_info if is_first_chunk else {},
-                )
-                if generation_chunk is None:
-                    continue
-
-                # custom code
-                if len(chunk['choices']) > 0 and 'reasoning_content' in chunk['choices'][0]['delta']:
-                    generation_chunk.message.additional_kwargs["reasoning_content"] = chunk['choices'][0]['delta'][
-                        'reasoning_content']
-
-                default_chunk_class = generation_chunk.message.__class__
-                logprobs = (generation_chunk.generation_info or {}).get("logprobs")
-                if run_manager:
-                    run_manager.on_llm_new_token(
-                        generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
-                    )
-                is_first_chunk = False
-                # custom code
-                if generation_chunk.message.usage_metadata is not None:
-                    self.usage_metadata = generation_chunk.message.usage_metadata
-                yield generation_chunk
-
-    def _create_chat_result(self,
-                            response: Union[dict, openai.BaseModel],
-                            generation_info: Optional[Dict] = None):
-        result = super()._create_chat_result(response, generation_info)
-        try:
-            reasoning_content = ''
-            reasoning_content_enable = False
-            for res in response.choices:
-                if 'reasoning_content' in res.message.model_extra:
-                    reasoning_content_enable = True
-                    _reasoning_content = res.message.model_extra.get('reasoning_content')
-                    if _reasoning_content is not None:
-                        reasoning_content += _reasoning_content
-            if reasoning_content_enable:
-                result.llm_output['reasoning_content'] = reasoning_content
-        except Exception as e:
-            pass
-        return result
+            return generation_chunk
+
+        choice = choices[0]
+        if choice["delta"] is None:
+            return None
+
+        message_chunk = _convert_delta_to_message_chunk(
+            choice["delta"], default_chunk_class
+        )
+        generation_info = {**base_generation_info} if base_generation_info else {}
+
+        if finish_reason := choice.get("finish_reason"):
+            generation_info["finish_reason"] = finish_reason
+        if model_name := chunk.get("model"):
+            generation_info["model_name"] = model_name
+        if system_fingerprint := chunk.get("system_fingerprint"):
+            generation_info["system_fingerprint"] = system_fingerprint
+
+        logprobs = choice.get("logprobs")
+        if logprobs:
+            generation_info["logprobs"] = logprobs
+
+        if usage_metadata and isinstance(message_chunk, AIMessageChunk):
+            message_chunk.usage_metadata = usage_metadata
+
+        generation_chunk = ChatGenerationChunk(
+            message=message_chunk, generation_info=generation_info or None
+        )
+        return generation_chunk
 
     def invoke(
             self,
            input: LanguageModelInput,
            config: Optional[RunnableConfig] = None,
            *,
-            stop: Optional[List[str]] = None,
+            stop: Optional[list[str]] = None,
            **kwargs: Any,
     ) -> BaseMessage:
         config = ensure_config(config)
         chat_result = cast(
-            ChatGeneration,
+            "ChatGeneration",
             self.generate_prompt(
                 [self._convert_input(input)],
                 stop=stop,
@@ -162,7 +190,9 @@ def invoke(
                 run_id=config.pop("run_id", None),
                 **kwargs,
             ).generations[0][0],
+
         ).message
+
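+        # Cache the provider-reported token usage from this response so the
+        # token-counting helpers above can reuse it.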
         self.usage_metadata = chat_result.response_metadata[
             'token_usage'] if 'token_usage' in chat_result.response_metadata else chat_result.usage_metadata
         return chat_result