@@ -1,3 +1,5 @@
+# type:ignore
+
 import json
 from typing import AsyncGenerator, Dict
 from unittest.mock import MagicMock, patch
@@ -7,6 +9,7 @@
 from langchain_aws import BedrockLLM
 from langchain_aws.llms.bedrock import (
     ALTERNATION_ERROR,
+    LLMInputOutputAdapter,
     _human_assistant_format,
 )
 
@@ -306,3 +309,141 @@ async def test_bedrock_async_streaming_call() -> None:
     assert chunks[0] == "nice"
     assert chunks[1] == " to meet"
     assert chunks[2] == " you"
+
+
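+# Each fixture below fakes a boto3 invoke_model response: body.read() returns
+# the provider-specific JSON payload, and the token counts come from the
+# x-amzn-bedrock-*-token-count HTTP headers in ResponseMetadata.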
+@pytest.fixture
+def mistral_response():
+    body = MagicMock()
+    body.read.return_value = json.dumps(
+        {"outputs": [{"text": "This is the Mistral output text."}]}
+    ).encode()
+    response = dict(
+        body=body,
+        ResponseMetadata={
+            "HTTPHeaders": {
+                "x-amzn-bedrock-input-token-count": "18",
+                "x-amzn-bedrock-output-token-count": "28",
+            }
+        },
+    )
+
+    return response
+
+
+@pytest.fixture
+def cohere_response():
+    body = MagicMock()
+    body.read.return_value = json.dumps(
+        {"generations": [{"text": "This is the Cohere output text."}]}
+    ).encode()
+    response = dict(
+        body=body,
+        ResponseMetadata={
+            "HTTPHeaders": {
+                "x-amzn-bedrock-input-token-count": "12",
+                "x-amzn-bedrock-output-token-count": "22",
+            }
+        },
+    )
+    return response
+
+
+@pytest.fixture
+def anthropic_response():
+    body = MagicMock()
+    body.read.return_value = json.dumps(
+        {"completion": "This is the output text."}
+    ).encode()
+    response = dict(
+        body=body,
+        ResponseMetadata={
+            "HTTPHeaders": {
+                "x-amzn-bedrock-input-token-count": "10",
+                "x-amzn-bedrock-output-token-count": "20",
+            }
+        },
+    )
+    return response
+
+
+@pytest.fixture
+def ai21_response():
+    body = MagicMock()
+    body.read.return_value = json.dumps(
+        {"completions": [{"data": {"text": "This is the AI21 output text."}}]}
+    ).encode()
+    response = dict(
+        body=body,
+        ResponseMetadata={
+            "HTTPHeaders": {
+                "x-amzn-bedrock-input-token-count": "15",
+                "x-amzn-bedrock-output-token-count": "25",
+            }
+        },
+    )
+    return response
+
+
+@pytest.fixture
+def response_with_stop_reason():
+    body = MagicMock()
+    body.read.return_value = json.dumps(
+        {"completion": "This is the output text.", "stop_reason": "length"}
+    ).encode()
+    response = dict(
+        body=body,
+        ResponseMetadata={
+            "HTTPHeaders": {
+                "x-amzn-bedrock-input-token-count": "10",
+                "x-amzn-bedrock-output-token-count": "20",
+            }
+        },
+    )
+    return response
+
+
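+# prepare_output should pull the generated text out of each provider's
+# response shape and derive usage from the token-count headers
+# (total_tokens = prompt_tokens + completion_tokens).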
+def test_prepare_output_for_mistral(mistral_response):
+    result = LLMInputOutputAdapter.prepare_output("mistral", mistral_response)
+    assert result["text"] == "This is the Mistral output text."
+    assert result["usage"]["prompt_tokens"] == 18
+    assert result["usage"]["completion_tokens"] == 28
+    assert result["usage"]["total_tokens"] == 46
+    assert result["stop_reason"] is None
+
+
+def test_prepare_output_for_cohere(cohere_response):
+    result = LLMInputOutputAdapter.prepare_output("cohere", cohere_response)
+    assert result["text"] == "This is the Cohere output text."
+    assert result["usage"]["prompt_tokens"] == 12
+    assert result["usage"]["completion_tokens"] == 22
+    assert result["usage"]["total_tokens"] == 34
+    assert result["stop_reason"] is None
+
+
+def test_prepare_output_with_stop_reason(response_with_stop_reason):
+    result = LLMInputOutputAdapter.prepare_output(
+        "anthropic", response_with_stop_reason
+    )
+    assert result["text"] == "This is the output text."
+    assert result["usage"]["prompt_tokens"] == 10
+    assert result["usage"]["completion_tokens"] == 20
+    assert result["usage"]["total_tokens"] == 30
+    assert result["stop_reason"] == "length"
+
+
+def test_prepare_output_for_anthropic(anthropic_response):
+    result = LLMInputOutputAdapter.prepare_output("anthropic", anthropic_response)
+    assert result["text"] == "This is the output text."
+    assert result["usage"]["prompt_tokens"] == 10
+    assert result["usage"]["completion_tokens"] == 20
+    assert result["usage"]["total_tokens"] == 30
+    assert result["stop_reason"] is None
+
+
+def test_prepare_output_for_ai21(ai21_response):
+    result = LLMInputOutputAdapter.prepare_output("ai21", ai21_response)
+    assert result["text"] == "This is the AI21 output text."
+    assert result["usage"]["prompt_tokens"] == 15
+    assert result["usage"]["completion_tokens"] == 25
+    assert result["usage"]["total_tokens"] == 40
+    assert result["stop_reason"] is None