1- from typing import Dict , Optional , Any , Iterator , cast , Mapping
1+ from collections .abc import Iterator , Mapping
2+ from typing import Any , cast
23
34from langchain_core .language_models import LanguageModelInput
4- from langchain_core .messages import BaseMessage , BaseMessageChunk , HumanMessageChunk , AIMessageChunk , \
5- SystemMessageChunk , FunctionMessageChunk , ChatMessageChunk
5+ from langchain_core .messages import (
6+ AIMessageChunk ,
7+ BaseMessage ,
8+ BaseMessageChunk ,
9+ ChatMessageChunk ,
10+ FunctionMessageChunk ,
11+ HumanMessageChunk ,
12+ SystemMessageChunk ,
13+ )
614from langchain_core .messages .ai import UsageMetadata
7- from langchain_core .messages .tool import tool_call_chunk , ToolMessageChunk
15+ from langchain_core .messages .tool import ToolMessageChunk , tool_call_chunk
816from langchain_core .outputs import ChatGenerationChunk
17+ from langchain_core .outputs .chat_generation import ChatGeneration
918from langchain_core .runnables import RunnableConfig , ensure_config
1019from langchain_openai import ChatOpenAI
1120from langchain_openai .chat_models .base import _create_usage_metadata
@@ -75,7 +84,7 @@ class BaseChatOpenAI(ChatOpenAI):
7584
7685 # custom_get_token_ids = custom_get_token_ids
7786
78- def get_last_generation_info (self ) -> Optional [ Dict [ str , Any ]] :
87+ def get_last_generation_info (self ) -> dict [ str , Any ] | None :
7988 return self .usage_metadata
8089
8190 def _stream (self , * args : Any , ** kwargs : Any ) -> Iterator [ChatGenerationChunk ]:
@@ -86,11 +95,11 @@ def _stream(self, *args: Any, **kwargs: Any) -> Iterator[ChatGenerationChunk]:
8695 yield chunk
8796
8897 def _convert_chunk_to_generation_chunk (
89- self ,
90- chunk : dict ,
91- default_chunk_class : type ,
92- base_generation_info : Optional [ dict ] ,
93- ) -> Optional [ ChatGenerationChunk ] :
98+ self ,
99+ chunk : dict ,
100+ default_chunk_class : type ,
101+ base_generation_info : dict | None ,
102+ ) -> ChatGenerationChunk | None :
94103 if chunk .get ("type" ) == "content.delta" : # from beta.chat.completions.stream
95104 return None
96105 token_usage = chunk .get ("usage" )
@@ -100,8 +109,10 @@ def _convert_chunk_to_generation_chunk(
100109 or chunk .get ("chunk" , {}).get ("choices" , [])
101110 )
102111
103- usage_metadata : Optional [UsageMetadata ] = (
104- _create_usage_metadata (token_usage ) if token_usage and token_usage .get ("prompt_tokens" ) else None
112+ usage_metadata : UsageMetadata | None = (
113+ _create_usage_metadata (token_usage )
114+ if token_usage and token_usage .get ("prompt_tokens" )
115+ else None
105116 )
106117 if len (choices ) == 0 :
107118 # logprobs is implicitly None
@@ -139,16 +150,16 @@ def _convert_chunk_to_generation_chunk(
139150 return generation_chunk
140151
141152 def invoke (
142- self ,
143- input : LanguageModelInput ,
144- config : Optional [ RunnableConfig ] = None ,
145- * ,
146- stop : Optional [ list [str ]] = None ,
147- ** kwargs : Any ,
153+ self ,
154+ input : LanguageModelInput ,
155+ config : RunnableConfig | None = None ,
156+ * ,
157+ stop : list [str ] | None = None ,
158+ ** kwargs : Any ,
148159 ) -> BaseMessage :
149160 config = ensure_config (config )
150161 chat_result = cast (
151- " ChatGeneration" ,
162+ ChatGeneration ,
152163 self .generate_prompt (
153164 [self ._convert_input (input )],
154165 stop = stop ,
@@ -159,7 +170,6 @@ def invoke(
159170 run_id = config .pop ("run_id" , None ),
160171 ** kwargs ,
161172 ).generations [0 ][0 ],
162-
163173 ).message
164174
165175 self .usage_metadata = chat_result .response_metadata [
0 commit comments