@@ -35,7 +35,10 @@
 from langchain_core.tools import BaseTool
 
 from langchain_aws.function_calling import convert_to_anthropic_tool, get_system_message
-from langchain_aws.llms.bedrock import BedrockBase
+from langchain_aws.llms.bedrock import (
+    BedrockBase,
+    _combine_generation_info_for_llm_result,
+)
 from langchain_aws.utils import (
     get_num_tokens_anthropic,
     get_token_ids_anthropic,
@@ -383,7 +386,13 @@ def _stream(
             **kwargs,
         ):
             delta = chunk.text
-            yield ChatGenerationChunk(message=AIMessageChunk(content=delta))
+            yield ChatGenerationChunk(
+                message=AIMessageChunk(
+                    content=delta, response_metadata=chunk.generation_info
+                )
+                if chunk.generation_info is not None
+                else AIMessageChunk(content=delta)
+            )
 
     def _generate(
         self,
@@ -393,11 +402,18 @@ def _generate(
         **kwargs: Any,
     ) -> ChatResult:
         completion = ""
-        llm_output: Dict[str, Any] = {"model_id": self.model_id}
-        usage_info: Dict[str, Any] = {}
+        llm_output: Dict[str, Any] = {}
+        provider_stop_reason_code = self.provider_stop_reason_key_map.get(
+            self._get_provider(), "stop_reason"
+        )
         if self.streaming:
+            response_metadata: List[Dict[str, Any]] = []
             for chunk in self._stream(messages, stop, run_manager, **kwargs):
                 completion += chunk.text
+                response_metadata.append(chunk.message.response_metadata)
+            llm_output = _combine_generation_info_for_llm_result(
+                response_metadata, provider_stop_reason_code
+            )
         else:
             provider = self._get_provider()
             prompt, system, formatted_messages = None, None, None
@@ -420,7 +436,7 @@ def _generate(
             if stop:
                 params["stop_sequences"] = stop
 
-            completion, usage_info = self._prepare_input_and_invoke(
+            completion, llm_output = self._prepare_input_and_invoke(
                 prompt=prompt,
                 stop=stop,
                 run_manager=run_manager,
@@ -429,14 +445,11 @@ def _generate(
                 **params,
             )
 
-            llm_output["usage"] = usage_info
-
+        llm_output["model_id"] = self.model_id
         return ChatResult(
             generations=[
                 ChatGeneration(
-                    message=AIMessage(
-                        content=completion, additional_kwargs={"usage": usage_info}
-                    )
+                    message=AIMessage(content=completion, additional_kwargs=llm_output)
                 )
             ],
             llm_output=llm_output,
@@ -447,7 +460,7 @@ def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
         final_output = {}
         for output in llm_outputs:
             output = output or {}
-            usage = output.pop("usage", {})
+            usage = output.get("usage", {})
             for token_type, token_count in usage.items():
                 final_usage[token_type] += token_count
             final_output.update(output)
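
A minimal caller-side sketch of what this change surfaces. Assumptions: ChatBedrock is the chat model class defined in this module, the model id and prompts are placeholders, and the exact metadata keys (e.g. "usage", "stop_reason") depend on the Bedrock provider and on what the stream handler reports in generation_info.

# Illustrative only; not part of the diff above.
from langchain_aws import ChatBedrock

llm = ChatBedrock(
    model_id="anthropic.claude-3-sonnet-20240229-v1:0",  # placeholder model id
)

# Streaming path (_stream): each AIMessageChunk now carries the chunk's
# generation_info, when the provider supplies one, in response_metadata in
# addition to the plain text delta.
for chunk in llm.stream("Write a haiku about rivers."):
    print(chunk.content, chunk.response_metadata)

# Invoke path (_generate): llm_output (model_id plus whatever the provider
# reported, e.g. a "usage" dict) is attached to the returned message's
# additional_kwargs and to ChatResult.llm_output; _combine_llm_outputs now
# reads it with output.get("usage", {}) instead of pop, leaving each
# per-generation output dict intact.
message = llm.invoke("Write a haiku about rivers.")
print(message.additional_kwargs)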