3 files changed: +27 −22 lines changed (one file under tests/integration_tests/chat_models)
@@ -565,20 +565,16 @@ def _generate(
             )
         else:
             usage_metadata = None
+
         llm_output["model_id"] = self.model_id
-        if len(tool_calls) > 0:
-            msg = AIMessage(
-                content=completion,
-                additional_kwargs=llm_output,
-                tool_calls=cast(List[ToolCall], tool_calls),
-                usage_metadata=usage_metadata,
-            )
-        else:
-            msg = AIMessage(
-                content=completion,
-                additional_kwargs=llm_output,
-                usage_metadata=usage_metadata,
-            )
+
+        msg = AIMessage(
+            content=completion,
+            additional_kwargs=llm_output,
+            tool_calls=cast(List[ToolCall], tool_calls),
+            usage_metadata=usage_metadata,
+        )
+
         return ChatResult(
             generations=[
                 ChatGeneration(
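The consolidation works because AIMessage already defaults tool_calls to an empty list, so constructing it with tool_calls=[] is equivalent to omitting the argument entirely. A minimal sketch of that equivalence, assuming a recent langchain-core:

```python
# Sketch: an empty tool_calls list behaves the same as no tool_calls
# argument at all, which is why the removed if/else was redundant.
from langchain_core.messages import AIMessage

with_empty = AIMessage(content="hi", tool_calls=[])
without = AIMessage(content="hi")
assert with_empty.tool_calls == without.tool_calls == []
```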
@@ -117,7 +117,11 @@ def _stream_response_to_generation_chunk(
         return None
     else:
         # chunk obj format varies with provider
-        generation_info = {k: v for k, v in stream_response.items() if k != output_key}
+        generation_info = {
+            k: v
+            for k, v in stream_response.items()
+            if k not in [output_key, "prompt_token_count", "generation_token_count"]
+        }
     return GenerationChunk(
         text=(
             stream_response[output_key]
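To make the new filter concrete, here is the same comprehension run against a hypothetical Meta (Llama) stream event; the payload shape below is an assumption for illustration. The token-count keys are dropped from generation_info because they are now surfaced through usage metadata instead:

```python
# Hypothetical Meta (Llama) stream event; the real payload may differ.
stream_response = {
    "generation": "Hello",
    "prompt_token_count": 12,
    "generation_token_count": 1,
    "stop_reason": None,
}
output_key = "generation"

# Same comprehension as the diff: drop the output text and the token
# counts, keeping only provider-specific metadata.
generation_info = {
    k: v
    for k, v in stream_response.items()
    if k not in [output_key, "prompt_token_count", "generation_token_count"]
}
assert generation_info == {"stop_reason": None}
```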
@@ -347,6 +351,10 @@ def prepare_output_stream(
             yield _get_invocation_metrics_chunk(chunk_obj)
             return

+        elif provider == "meta" and chunk_obj.get("stop_reason", "") == "stop":
+            yield _get_invocation_metrics_chunk(chunk_obj)
+            return
+
         elif messages_api and (chunk_obj.get("type") == "message_stop"):
             yield _get_invocation_metrics_chunk(chunk_obj)
             return
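A sketch of the event the new branch matches: the final chunk of a Meta (Llama 3) stream carries stop_reason == "stop", at which point the invocation-metrics chunk is yielded and the generator returns. Apart from stop_reason, the field names below are illustrative assumptions, not a verified payload:

```python
# Illustrative final event from a Llama 3 stream.
final_chunk = {
    "generation": "",
    "stop_reason": "stop",
    "amazon-bedrock-invocationMetrics": {
        "inputTokenCount": 12,
        "outputTokenCount": 8,
    },
}

if final_chunk.get("stop_reason", "") == "stop":
    # In the diff this yields _get_invocation_metrics_chunk(chunk_obj)
    # and then ends the stream with a bare return.
    print("final chunk: emit invocation metrics, then stop")
```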
tests/integration_tests/chat_models
@@ -84,17 +84,18 @@ def test_chat_bedrock_streaming() -> None:
 @pytest.mark.scheduled
 def test_chat_bedrock_streaming_llama3() -> None:
     """Test that streaming correctly invokes on_llm_new_token callback."""
-    callback_handler = FakeCallbackHandler()
     chat = ChatBedrock(  # type: ignore[call-arg]
-        model_id="meta.llama3-8b-instruct-v1:0",
-        streaming=True,
-        callbacks=[callback_handler],
-        verbose=True,
+        model_id="meta.llama3-8b-instruct-v1:0"
     )
     message = HumanMessage(content="Hello")
-    response = chat([message])
-    assert callback_handler.llm_streams > 0
-    assert isinstance(response, BaseMessage)
+
+    response = AIMessageChunk(content="")
+    for chunk in chat.stream([message]):
+        response += chunk  # type: ignore[assignment]
+
+    assert response.content
+    assert response.response_metadata
+    assert response.usage_metadata


 @pytest.mark.scheduled
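The rewritten test relies on AIMessageChunk.__add__, which concatenates content and, in recent langchain-core versions, merges usage_metadata by summing token counts. A standalone sketch of that accumulation, with no Bedrock call required (exact merge behavior may vary by langchain-core version):

```python
from langchain_core.messages import AIMessageChunk

a = AIMessageChunk(
    content="Hel",
    usage_metadata={"input_tokens": 5, "output_tokens": 1, "total_tokens": 6},
)
b = AIMessageChunk(
    content="lo",
    usage_metadata={"input_tokens": 0, "output_tokens": 1, "total_tokens": 1},
)

merged = a + b
assert merged.content == "Hello"
print(merged.usage_metadata)  # expected: input 5, output 2, total 7 (summed)
```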