Skip to content

Commit 57ea56b

Browse files
authored
Merge pull request #57 from 3coins/fix-stop-reason-error
Fixes an error raised when the stop reason is missing from the model response.
2 parents 9adc14d + 83bb6b0 commit 57ea56b

File tree

3 files changed

+144
-1
lines changed

3 files changed

+144
-1
lines changed

libs/aws/langchain_aws/llms/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
Bedrock,
44
BedrockBase,
55
BedrockLLM,
6+
LLMInputOutputAdapter,
67
)
78
from langchain_aws.llms.sagemaker_endpoint import SagemakerEndpoint
89

@@ -11,5 +12,6 @@
1112
"Bedrock",
1213
"BedrockBase",
1314
"BedrockLLM",
15+
"LLMInputOutputAdapter",
1416
"SagemakerEndpoint",
1517
]

libs/aws/langchain_aws/llms/bedrock.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ def prepare_output(cls, provider: str, response: Any) -> dict:
257257
"completion_tokens": completion_tokens,
258258
"total_tokens": prompt_tokens + completion_tokens,
259259
},
260-
"stop_reason": response_body["stop_reason"],
260+
"stop_reason": response_body.get("stop_reason"),
261261
}
262262

263263
@classmethod

libs/aws/tests/unit_tests/llms/test_bedrock.py

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
# type:ignore
2+
13
import json
24
from typing import AsyncGenerator, Dict
35
from unittest.mock import MagicMock, patch
@@ -7,6 +9,7 @@
79
from langchain_aws import BedrockLLM
810
from langchain_aws.llms.bedrock import (
911
ALTERNATION_ERROR,
12+
LLMInputOutputAdapter,
1013
_human_assistant_format,
1114
)
1215

@@ -306,3 +309,141 @@ async def test_bedrock_async_streaming_call() -> None:
306309
assert chunks[0] == "nice"
307310
assert chunks[1] == " to meet"
308311
assert chunks[2] == " you"
312+
313+
314+
@pytest.fixture
def mistral_response():
    """Mocked Bedrock invoke_model response dict for the Mistral provider."""
    payload = {"outputs": [{"text": "This is the Mistral output text."}]}
    mock_body = MagicMock()
    # Bedrock responses expose the JSON payload via a readable streaming body.
    mock_body.read.return_value = json.dumps(payload).encode()
    headers = {
        "x-amzn-bedrock-input-token-count": "18",
        "x-amzn-bedrock-output-token-count": "28",
    }
    return {"body": mock_body, "ResponseMetadata": {"HTTPHeaders": headers}}
331+
332+
333+
@pytest.fixture
def cohere_response():
    """Mocked Bedrock invoke_model response dict for the Cohere provider."""
    payload = {"generations": [{"text": "This is the Cohere output text."}]}
    mock_body = MagicMock()
    # Bedrock responses expose the JSON payload via a readable streaming body.
    mock_body.read.return_value = json.dumps(payload).encode()
    headers = {
        "x-amzn-bedrock-input-token-count": "12",
        "x-amzn-bedrock-output-token-count": "22",
    }
    return {"body": mock_body, "ResponseMetadata": {"HTTPHeaders": headers}}
349+
350+
351+
@pytest.fixture
def anthropic_response():
    """Mocked Bedrock invoke_model response dict for the Anthropic provider.

    Deliberately omits "stop_reason" so callers can verify the missing-key path.
    """
    payload = {"completion": "This is the output text."}
    mock_body = MagicMock()
    mock_body.read.return_value = json.dumps(payload).encode()
    headers = {
        "x-amzn-bedrock-input-token-count": "10",
        "x-amzn-bedrock-output-token-count": "20",
    }
    return {"body": mock_body, "ResponseMetadata": {"HTTPHeaders": headers}}
367+
368+
369+
@pytest.fixture
def ai21_response():
    """Mocked Bedrock invoke_model response dict for the AI21 provider."""
    payload = {"completions": [{"data": {"text": "This is the AI21 output text."}}]}
    mock_body = MagicMock()
    # Bedrock responses expose the JSON payload via a readable streaming body.
    mock_body.read.return_value = json.dumps(payload).encode()
    headers = {
        "x-amzn-bedrock-input-token-count": "15",
        "x-amzn-bedrock-output-token-count": "25",
    }
    return {"body": mock_body, "ResponseMetadata": {"HTTPHeaders": headers}}
385+
386+
387+
@pytest.fixture
def response_with_stop_reason():
    """Mocked Anthropic-style response whose body DOES carry a stop_reason."""
    payload = {"completion": "This is the output text.", "stop_reason": "length"}
    mock_body = MagicMock()
    mock_body.read.return_value = json.dumps(payload).encode()
    headers = {
        "x-amzn-bedrock-input-token-count": "10",
        "x-amzn-bedrock-output-token-count": "20",
    }
    return {"body": mock_body, "ResponseMetadata": {"HTTPHeaders": headers}}
403+
404+
405+
def test_prepare_output_for_mistral(mistral_response):
    """Mistral text/usage are extracted; stop_reason defaults to None when absent."""
    output = LLMInputOutputAdapter.prepare_output("mistral", mistral_response)
    usage = output["usage"]
    assert output["text"] == "This is the Mistral output text."
    assert (usage["prompt_tokens"], usage["completion_tokens"]) == (18, 28)
    assert usage["total_tokens"] == 46
    assert output["stop_reason"] is None
412+
413+
414+
def test_prepare_output_for_cohere(cohere_response):
    """Cohere text/usage are extracted; stop_reason defaults to None when absent."""
    output = LLMInputOutputAdapter.prepare_output("cohere", cohere_response)
    usage = output["usage"]
    assert output["text"] == "This is the Cohere output text."
    assert (usage["prompt_tokens"], usage["completion_tokens"]) == (12, 22)
    assert usage["total_tokens"] == 34
    assert output["stop_reason"] is None
421+
422+
423+
def test_prepare_output_with_stop_reason(response_with_stop_reason):
    """A stop_reason present in the response body is passed through unchanged."""
    output = LLMInputOutputAdapter.prepare_output(
        "anthropic", response_with_stop_reason
    )
    usage = output["usage"]
    assert output["text"] == "This is the output text."
    assert (usage["prompt_tokens"], usage["completion_tokens"]) == (10, 20)
    assert usage["total_tokens"] == 30
    assert output["stop_reason"] == "length"
432+
433+
434+
def test_prepare_output_for_anthropic(anthropic_response):
    """Anthropic text/usage are extracted; missing stop_reason yields None."""
    output = LLMInputOutputAdapter.prepare_output("anthropic", anthropic_response)
    usage = output["usage"]
    assert output["text"] == "This is the output text."
    assert (usage["prompt_tokens"], usage["completion_tokens"]) == (10, 20)
    assert usage["total_tokens"] == 30
    assert output["stop_reason"] is None
441+
442+
443+
def test_prepare_output_for_ai21(ai21_response):
    """AI21 text/usage are extracted; stop_reason defaults to None when absent."""
    output = LLMInputOutputAdapter.prepare_output("ai21", ai21_response)
    usage = output["usage"]
    assert output["text"] == "This is the AI21 output text."
    assert (usage["prompt_tokens"], usage["completion_tokens"]) == (15, 25)
    assert usage["total_tokens"] == 40
    assert output["stop_reason"] is None

0 commit comments

Comments
 (0)