@@ -65,14 +65,15 @@ def __init__(
6565 self ._message = None
6666 self ._content_block = {}
6767 self ._record_message = False
68+ self ._ended = False
6869
6970 def __iter__ (self ):
7071 try :
7172 for event in self .__wrapped__ :
7273 self ._process_event (event )
7374 yield event
7475 except EventStreamError as exc :
75- self ._stream_error_callback (exc )
76+ self ._handle_stream_error (exc )
7677 raise
7778
7879 def _process_event (self , event ):
@@ -133,15 +134,22 @@ def _process_event(self, event):
133134
134135 if output_tokens := usage .get ("outputTokens" ):
135136 self ._response ["usage" ]["outputTokens" ] = output_tokens
136-
137- self ._stream_done_callback (self ._response )
137+ self ._complete_stream (self ._response )
138138
139139 return
140140
141141 def close (self ):
142142 self .__wrapped__ .close ()
143143 # Treat the stream as done to ensure the span end.
144- self ._stream_done_callback (self ._response )
144+ self ._complete_stream (self ._response )
145+
146+ def _complete_stream (self , response ):
147+ self ._stream_done_callback (response , self ._ended )
148+ self ._ended = True
149+
150+ def _handle_stream_error (self , exc ):
151+ self ._stream_error_callback (exc , self ._ended )
152+ self ._ended = True
145153
146154
147155# pylint: disable=abstract-method
@@ -168,19 +176,28 @@ def __init__(
168176 self ._content_block = {}
169177 self ._tool_json_input_buf = ""
170178 self ._record_message = False
179+ self ._ended = False
171180
172181 def close (self ):
173182 self .__wrapped__ .close ()
174183 # Treat the stream as done to ensure the span end.
175- self ._stream_done_callback (self ._response )
184+ self ._stream_done_callback (self ._response , self ._ended )
185+
186+ def _complete_stream (self , response ):
187+ self ._stream_done_callback (response , self ._ended )
188+ self ._ended = True
189+
190+ def _handle_stream_error (self , exc ):
191+ self ._stream_error_callback (exc , self ._ended )
192+ self ._ended = True
176193
177194 def __iter__ (self ):
178195 try :
179196 for event in self .__wrapped__ :
180197 self ._process_event (event )
181198 yield event
182199 except EventStreamError as exc :
183- self ._stream_error_callback (exc )
200+ self ._handle_stream_error (exc )
184201 raise
185202
186203 def _process_event (self , event ):
@@ -223,7 +240,7 @@ def _process_amazon_titan_chunk(self, chunk):
223240 self ._response ["output" ] = {
224241 "message" : {"content" : [{"text" : chunk ["outputText" ]}]}
225242 }
226- self ._stream_done_callback (self ._response )
243+ self ._complete_stream (self ._response )
227244
228245 def _process_amazon_nova_chunk (self , chunk ):
229246 # pylint: disable=too-many-branches
@@ -293,7 +310,7 @@ def _process_amazon_nova_chunk(self, chunk):
293310 if output_tokens := usage .get ("outputTokens" ):
294311 self ._response ["usage" ]["outputTokens" ] = output_tokens
295312
296- self ._stream_done_callback (self ._response )
313+ self ._complete_stream (self ._response )
297314 return
298315
299316 def _process_anthropic_claude_chunk (self , chunk ):
@@ -365,7 +382,7 @@ def _process_anthropic_claude_chunk(self, chunk):
365382 self ._record_message = False
366383 self ._message = None
367384
368- self ._stream_done_callback (self ._response )
385+ self ._complete_stream (self ._response )
369386 return
370387
371388
0 commit comments