 # coding=utf-8
-import warnings
-from typing import List, Dict, Optional, Any, Iterator, cast, Type, Union
+from typing import Dict, Optional, Any, Iterator, cast, Union, Sequence, Callable, Mapping
 
-import openai
-from langchain_core.callbacks import CallbackManagerForLLMRun
 from langchain_core.language_models import LanguageModelInput
-from langchain_core.messages import BaseMessage, get_buffer_string, BaseMessageChunk, AIMessageChunk
-from langchain_core.outputs import ChatGenerationChunk, ChatGeneration
+from langchain_core.messages import BaseMessage, get_buffer_string, BaseMessageChunk, HumanMessageChunk, AIMessageChunk, \
+    SystemMessageChunk, FunctionMessageChunk, ChatMessageChunk
+from langchain_core.messages.ai import UsageMetadata
+from langchain_core.messages.tool import tool_call_chunk, ToolMessageChunk
+from langchain_core.outputs import ChatGenerationChunk
 from langchain_core.runnables import RunnableConfig, ensure_config
-from langchain_core.utils.pydantic import is_basemodel_subclass
+from langchain_core.tools import BaseTool
 from langchain_openai import ChatOpenAI
+from langchain_openai.chat_models.base import _create_usage_metadata
 
 from common.config.tokenizer_manage_config import TokenizerManage
 
@@ -19,14 +20,79 @@ def custom_get_token_ids(text: str):
     return tokenizer.encode(text)
 
 
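+# Local port of langchain_openai's private _convert_delta_to_message_chunk helper,
+# extended to copy the provider-specific 'reasoning_content' delta field (emitted by
+# reasoning models such as DeepSeek-R1) into additional_kwargs.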
+def _convert_delta_to_message_chunk(
+    _dict: Mapping[str, Any], default_class: type[BaseMessageChunk]
+) -> BaseMessageChunk:
+    id_ = _dict.get("id")
+    role = cast(str, _dict.get("role"))
+    content = cast(str, _dict.get("content") or "")
+    additional_kwargs: dict = {}
+    if 'reasoning_content' in _dict:
+        additional_kwargs['reasoning_content'] = _dict.get('reasoning_content')
+    if _dict.get("function_call"):
+        function_call = dict(_dict["function_call"])
+        if "name" in function_call and function_call["name"] is None:
+            function_call["name"] = ""
+        additional_kwargs["function_call"] = function_call
+    tool_call_chunks = []
+    if raw_tool_calls := _dict.get("tool_calls"):
+        additional_kwargs["tool_calls"] = raw_tool_calls
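+        # Build incremental tool-call chunks; a missing key (e.g. "index") silently
+        # leaves tool_call_chunks empty instead of aborting the stream.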
+        try:
+            tool_call_chunks = [
+                tool_call_chunk(
+                    name=rtc["function"].get("name"),
+                    args=rtc["function"].get("arguments"),
+                    id=rtc.get("id"),
+                    index=rtc["index"],
+                )
+                for rtc in raw_tool_calls
+            ]
+        except KeyError:
+            pass
+
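+    # Dispatch on the delta's role, falling back to default_class when the role is absent.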
+    if role == "user" or default_class == HumanMessageChunk:
+        return HumanMessageChunk(content=content, id=id_)
+    elif role == "assistant" or default_class == AIMessageChunk:
+        return AIMessageChunk(
+            content=content,
+            additional_kwargs=additional_kwargs,
+            id=id_,
+            tool_call_chunks=tool_call_chunks,  # type: ignore[arg-type]
+        )
+    elif role in ("system", "developer") or default_class == SystemMessageChunk:
+        if role == "developer":
+            additional_kwargs = {"__openai_role__": "developer"}
+        else:
+            additional_kwargs = {}
+        return SystemMessageChunk(
+            content=content, id=id_, additional_kwargs=additional_kwargs
+        )
+    elif role == "function" or default_class == FunctionMessageChunk:
+        return FunctionMessageChunk(content=content, name=_dict["name"], id=id_)
+    elif role == "tool" or default_class == ToolMessageChunk:
+        return ToolMessageChunk(
+            content=content, tool_call_id=_dict["tool_call_id"], id=id_
+        )
+    elif role or default_class == ChatMessageChunk:
+        return ChatMessageChunk(content=content, role=role, id=id_)
+    else:
+        return default_class(content=content, id=id_)  # type: ignore
+
+
 class BaseChatOpenAI(ChatOpenAI):
     usage_metadata: dict = {}
     custom_get_token_ids = custom_get_token_ids
 
     def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
         return self.usage_metadata
 
-    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
+    def get_num_tokens_from_messages(
+        self,
+        messages: list[BaseMessage],
+        tools: Optional[
+            Sequence[Union[dict[str, Any], type, Callable, BaseTool]]
+        ] = None,
+    ) -> int:
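+        # NOTE: 'tools' is accepted to match the newer base-class signature; the
+        # token estimate below does not take tool schemas into account.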
         if self.usage_metadata is None or self.usage_metadata == {}:
             try:
                 return super().get_num_tokens_from_messages(messages)
@@ -44,114 +110,77 @@ def get_num_tokens(self, text: str) -> int:
             return len(tokenizer.encode(text))
         return self.get_last_generation_info().get('output_tokens', 0)
 
-    def _stream(
+    def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]:
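+        # Always request usage stats; ChatOpenAI translates stream_usage=True into
+        # stream_options={"include_usage": True} on the underlying request.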
+        kwargs['stream_usage'] = True
+        for chunk in super()._stream(*args, **kwargs):
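+            # Cache the final usage metadata so get_last_generation_info() reflects this run.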
+            if chunk.message.usage_metadata is not None:
+                self.usage_metadata = chunk.message.usage_metadata
+            yield chunk
+
+    def _convert_chunk_to_generation_chunk(
         self,
-        messages: List[BaseMessage],
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> Iterator[ChatGenerationChunk]:
-        kwargs["stream"] = True
-        kwargs["stream_options"] = {"include_usage": True}
-        """Set default stream_options."""
-        stream_usage = self._should_stream_usage(kwargs.get('stream_usage'), **kwargs)
-        # Note: stream_options is not a valid parameter for Azure OpenAI.
-        # To support users proxying Azure through ChatOpenAI, here we only specify
-        # stream_options if include_usage is set to True.
-        # See https://learn.microsoft.com/en-us/azure/ai-services/openai/whats-new
-        # for release notes.
-        if stream_usage:
-            kwargs["stream_options"] = {"include_usage": stream_usage}
-
-        payload = self._get_request_payload(messages, stop=stop, **kwargs)
-        default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
-        base_generation_info = {}
-
-        if "response_format" in payload and is_basemodel_subclass(
-            payload["response_format"]
-        ):
-            # TODO: Add support for streaming with Pydantic response_format.
-            warnings.warn("Streaming with Pydantic response_format not yet supported.")
-            chat_result = self._generate(
-                messages, stop, run_manager=run_manager, **kwargs
-            )
-            msg = chat_result.generations[0].message
-            yield ChatGenerationChunk(
-                message=AIMessageChunk(
-                    **msg.dict(exclude={"type", "additional_kwargs"}),
-                    # preserve the "parsed" Pydantic object without converting to dict
-                    additional_kwargs=msg.additional_kwargs,
-                ),
-                generation_info=chat_result.generations[0].generation_info,
+        chunk: dict,
+        default_chunk_class: type,
+        base_generation_info: Optional[dict],
+    ) -> Optional[ChatGenerationChunk]:
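+        # Mirrors ChatOpenAI's private chunk converter, but parses deltas with the
+        # local _convert_delta_to_message_chunk so reasoning_content is preserved.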
+        if chunk.get("type") == "content.delta":  # from beta.chat.completions.stream
+            return None
+        token_usage = chunk.get("usage")
+        choices = (
+            chunk.get("choices", [])
+            # from beta.chat.completions.stream
+            or chunk.get("chunk", {}).get("choices", [])
+        )
+
+        usage_metadata: Optional[UsageMetadata] = (
+            _create_usage_metadata(token_usage) if token_usage else None
+        )
+        if len(choices) == 0:
+            # logprobs is implicitly None
+            generation_chunk = ChatGenerationChunk(
+                message=default_chunk_class(content="", usage_metadata=usage_metadata)
             )
-            return
-        if self.include_response_headers:
-            raw_response = self.client.with_raw_response.create(**payload)
-            response = raw_response.parse()
-            base_generation_info = {"headers": dict(raw_response.headers)}
-        else:
-            response = self.client.create(**payload)
-        with response:
-            is_first_chunk = True
-            for chunk in response:
-                if not isinstance(chunk, dict):
-                    chunk = chunk.model_dump()
-
-                generation_chunk = super()._convert_chunk_to_generation_chunk(
-                    chunk,
-                    default_chunk_class,
-                    base_generation_info if is_first_chunk else {},
-                )
-                if generation_chunk is None:
-                    continue
-
-                # custom code
-                if len(chunk['choices']) > 0 and 'reasoning_content' in chunk['choices'][0]['delta']:
-                    generation_chunk.message.additional_kwargs["reasoning_content"] = chunk['choices'][0]['delta'][
-                        'reasoning_content']
-
-                default_chunk_class = generation_chunk.message.__class__
-                logprobs = (generation_chunk.generation_info or {}).get("logprobs")
-                if run_manager:
-                    run_manager.on_llm_new_token(
-                        generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
-                    )
-                is_first_chunk = False
-                # custom code
-                if generation_chunk.message.usage_metadata is not None:
-                    self.usage_metadata = generation_chunk.message.usage_metadata
-                yield generation_chunk
-
-    def _create_chat_result(self,
-                            response: Union[dict, openai.BaseModel],
-                            generation_info: Optional[Dict] = None):
-        result = super()._create_chat_result(response, generation_info)
-        try:
-            reasoning_content = ''
-            reasoning_content_enable = False
-            for res in response.choices:
-                if 'reasoning_content' in res.message.model_extra:
-                    reasoning_content_enable = True
-                    _reasoning_content = res.message.model_extra.get('reasoning_content')
-                    if _reasoning_content is not None:
-                        reasoning_content += _reasoning_content
-            if reasoning_content_enable:
-                result.llm_output['reasoning_content'] = reasoning_content
-        except Exception as e:
-            pass
-        return result
+            return generation_chunk
+
+        choice = choices[0]
+        if choice["delta"] is None:
+            return None
+
+        message_chunk = _convert_delta_to_message_chunk(
+            choice["delta"], default_chunk_class
+        )
+        generation_info = {**base_generation_info} if base_generation_info else {}
+
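+        # Carry per-chunk metadata (finish reason, model name, fingerprint, logprobs)
+        # into generation_info, matching the upstream converter.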
+        if finish_reason := choice.get("finish_reason"):
+            generation_info["finish_reason"] = finish_reason
+        if model_name := chunk.get("model"):
+            generation_info["model_name"] = model_name
+        if system_fingerprint := chunk.get("system_fingerprint"):
+            generation_info["system_fingerprint"] = system_fingerprint
+
+        logprobs = choice.get("logprobs")
+        if logprobs:
+            generation_info["logprobs"] = logprobs
+
+        if usage_metadata and isinstance(message_chunk, AIMessageChunk):
+            message_chunk.usage_metadata = usage_metadata
+
+        generation_chunk = ChatGenerationChunk(
+            message=message_chunk, generation_info=generation_info or None
+        )
+        return generation_chunk
 
     def invoke(
         self,
         input: LanguageModelInput,
         config: Optional[RunnableConfig] = None,
         *,
-        stop: Optional[List[str]] = None,
+        stop: Optional[list[str]] = None,
         **kwargs: Any,
     ) -> BaseMessage:
         config = ensure_config(config)
         chat_result = cast(
-            ChatGeneration,
+            "ChatGeneration",
             self.generate_prompt(
                 [self._convert_input(input)],
                 stop=stop,
@@ -162,7 +191,9 @@ def invoke(
                 run_id=config.pop("run_id", None),
                 **kwargs,
             ).generations[0][0],
+
         ).message
+
         self.usage_metadata = chat_result.response_metadata[
             'token_usage'] if 'token_usage' in chat_result.response_metadata else chat_result.usage_metadata
         return chat_result