Skip to content

Commit e38c4c4

Browse files
test(bedrock): update test runner with proper fixtures and env handling
Co-Authored-By: [email protected] <[email protected]>
1 parent cda6d3c commit e38c4c4

File tree

1 file changed

+57
-47
lines changed

1 file changed

+57
-47
lines changed
Lines changed: 57 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,32 @@
11
import os
22
import json
3-
import boto3
4-
from unittest.mock import patch
3+
import pytest
4+
from typing import Dict, Iterator
5+
from unittest.mock import patch, MagicMock
6+
57
from langtrace_python_sdk import langtrace
68
from examples.awsbedrock_examples.converse import use_converse, use_converse_stream
79

8-
def mock_invoke_model(*args, **kwargs):
10+
@pytest.fixture
11+
def mock_env():
12+
"""Provide mock environment variables for testing."""
13+
with patch.dict(os.environ, {
14+
"LANGTRACE_API_KEY": "test_key",
15+
"AWS_ACCESS_KEY_ID": "test_aws_key",
16+
"AWS_SECRET_ACCESS_KEY": "test_aws_secret"
17+
}):
18+
yield
19+
20+
@pytest.fixture
21+
def mock_bedrock_client():
22+
"""Provide a mocked AWS Bedrock client."""
23+
with patch("boto3.client") as mock_client:
24+
mock_instance = mock_client.return_value
25+
mock_instance.invoke_model.side_effect = mock_invoke_model
26+
mock_instance.invoke_model_with_response_stream.side_effect = mock_invoke_model_with_response_stream
27+
yield mock_instance
28+
29+
def mock_invoke_model(*args, **kwargs) -> Dict:
930
"""Mock for standard completion with vendor attribute verification."""
1031
# Verify the request body contains all expected attributes
1132
body = json.loads(kwargs.get('body', '{}'))
@@ -19,19 +40,17 @@ def mock_invoke_model(*args, **kwargs):
1940
assert kwargs.get('modelId') == "anthropic.claude-3-haiku-20240307-v1:0", f"Incorrect modelId: {kwargs.get('modelId')}"
2041

2142
mock_response = {
22-
"completion": "Mocked response for testing with vendor attributes",
23-
"stop_reason": "stop_sequence",
24-
"usage": {
25-
"input_tokens": 10,
26-
"output_tokens": 20,
27-
"total_tokens": 30
43+
"output": {
44+
"message": {
45+
"content": [{"text": "Mocked response for testing with vendor attributes"}]
46+
}
2847
}
2948
}
3049
return {
3150
'body': json.dumps(mock_response).encode()
3251
}
3352

34-
def mock_invoke_model_with_response_stream(*args, **kwargs):
53+
def mock_invoke_model_with_response_stream(*args, **kwargs) -> Dict:
3554
"""Mock for streaming completion with vendor attribute verification."""
3655
# Verify the request body contains all expected attributes
3756
body = json.loads(kwargs.get('body', '{}'))
@@ -48,55 +67,46 @@ def mock_invoke_model_with_response_stream(*args, **kwargs):
4867
{
4968
'chunk': {
5069
'bytes': json.dumps({
51-
"completion": "Streaming chunk 1",
52-
"stop_reason": None
70+
"output": {
71+
"message": {
72+
"content": [{"text": "Streaming chunk 1"}]
73+
}
74+
}
5375
}).encode()
5476
}
5577
},
5678
{
5779
'chunk': {
5880
'bytes': json.dumps({
59-
"completion": "Streaming chunk 2",
60-
"stop_reason": "stop_sequence"
81+
"output": {
82+
"message": {
83+
"content": [{"text": "Streaming chunk 2"}]
84+
}
85+
}
6186
}).encode()
6287
}
6388
}
6489
]
6590
return {'body': chunks}
6691

67-
def run_test():
68-
"""Run tests for both standard and streaming completion."""
69-
# Initialize Langtrace with API key from environment
70-
langtrace.init(api_key=os.environ["LANGTRACE_API_KEY"])
71-
72-
with patch("boto3.client") as mock_client:
73-
mock_instance = mock_client.return_value
74-
mock_instance.invoke_model = mock_invoke_model
75-
mock_instance.invoke_model_with_response_stream = mock_invoke_model_with_response_stream
76-
77-
print("\nTesting AWS Bedrock instrumentation...")
78-
79-
try:
80-
# Test standard completion
81-
print("\nTesting standard completion...")
82-
response = use_converse("Tell me about OpenTelemetry")
83-
print(f"Standard completion response: {response}")
84-
print("✓ Standard completion test passed with vendor attributes")
85-
86-
# Test streaming completion
87-
print("\nTesting streaming completion...")
88-
chunks = []
89-
for chunk in use_converse_stream("What is distributed tracing?"):
90-
chunks.append(chunk)
91-
print(f"Streaming chunk: {chunk}")
92-
assert len(chunks) == 2, f"Expected 2 chunks, got {len(chunks)}"
93-
print(f"✓ Streaming completion test passed with {len(chunks)} chunks")
92+
@pytest.mark.usefixtures("mock_env")
93+
class TestAWSBedrock:
94+
"""Test suite for AWS Bedrock instrumentation."""
9495

95-
print("\n✓ All tests completed successfully!")
96-
except AssertionError as e:
97-
print(f"\n❌ Test failed: {str(e)}")
98-
raise
96+
def test_standard_completion(self, mock_bedrock_client):
97+
"""Test standard completion with mocked AWS client."""
98+
response = use_converse("Tell me about OpenTelemetry")
99+
assert response is not None
100+
content = response.get('output', {}).get('message', {}).get('content', [])
101+
assert content, "Response content should not be empty"
102+
assert isinstance(content[0].get('text'), str), "Response text should be a string"
99103

104+
def test_streaming_completion(self, mock_bedrock_client):
105+
"""Test streaming completion with mocked AWS client."""
106+
chunks = list(use_converse_stream("What is distributed tracing?"))
107+
assert len(chunks) == 2, f"Expected 2 chunks, got {len(chunks)}"
100108

101-
if __name__ == "__main__":
102-
run_test()
109+
for chunk in chunks:
110+
content = chunk.get('output', {}).get('message', {}).get('content', [])
111+
assert content, "Chunk content should not be empty"
112+
assert isinstance(content[0].get('text'), str), "Chunk text should be a string"

0 commit comments

Comments
 (0)