11import os
22import json
3- import boto3
4- from unittest .mock import patch
3+ import pytest
4+ from typing import Dict , Iterator
5+ from unittest .mock import patch , MagicMock
6+
57from langtrace_python_sdk import langtrace
68from examples .awsbedrock_examples .converse import use_converse , use_converse_stream
79
8- def mock_invoke_model (* args , ** kwargs ):
10+ @pytest .fixture
11+ def mock_env ():
12+ """Provide mock environment variables for testing."""
13+ with patch .dict (os .environ , {
14+ "LANGTRACE_API_KEY" : "test_key" ,
15+ "AWS_ACCESS_KEY_ID" : "test_aws_key" ,
16+ "AWS_SECRET_ACCESS_KEY" : "test_aws_secret"
17+ }):
18+ yield
19+
20+ @pytest .fixture
21+ def mock_bedrock_client ():
22+ """Provide a mocked AWS Bedrock client."""
23+ with patch ("boto3.client" ) as mock_client :
24+ mock_instance = mock_client .return_value
25+ mock_instance .invoke_model .side_effect = mock_invoke_model
26+ mock_instance .invoke_model_with_response_stream .side_effect = mock_invoke_model_with_response_stream
27+ yield mock_instance
28+
29+ def mock_invoke_model (* args , ** kwargs ) -> Dict :
930 """Mock for standard completion with vendor attribute verification."""
1031 # Verify the request body contains all expected attributes
1132 body = json .loads (kwargs .get ('body' , '{}' ))
@@ -19,19 +40,17 @@ def mock_invoke_model(*args, **kwargs):
1940 assert kwargs .get ('modelId' ) == "anthropic.claude-3-haiku-20240307-v1:0" , f"Incorrect modelId: { kwargs .get ('modelId' )} "
2041
2142 mock_response = {
22- "completion" : "Mocked response for testing with vendor attributes" ,
23- "stop_reason" : "stop_sequence" ,
24- "usage" : {
25- "input_tokens" : 10 ,
26- "output_tokens" : 20 ,
27- "total_tokens" : 30
43+ "output" : {
44+ "message" : {
45+ "content" : [{"text" : "Mocked response for testing with vendor attributes" }]
46+ }
2847 }
2948 }
3049 return {
3150 'body' : json .dumps (mock_response ).encode ()
3251 }
3352
34- def mock_invoke_model_with_response_stream (* args , ** kwargs ):
53+ def mock_invoke_model_with_response_stream (* args , ** kwargs ) -> Dict :
3554 """Mock for streaming completion with vendor attribute verification."""
3655 # Verify the request body contains all expected attributes
3756 body = json .loads (kwargs .get ('body' , '{}' ))
@@ -48,55 +67,46 @@ def mock_invoke_model_with_response_stream(*args, **kwargs):
4867 {
4968 'chunk' : {
5069 'bytes' : json .dumps ({
51- "completion" : "Streaming chunk 1" ,
52- "stop_reason" : None
70+ "output" : {
71+ "message" : {
72+ "content" : [{"text" : "Streaming chunk 1" }]
73+ }
74+ }
5375 }).encode ()
5476 }
5577 },
5678 {
5779 'chunk' : {
5880 'bytes' : json .dumps ({
59- "completion" : "Streaming chunk 2" ,
60- "stop_reason" : "stop_sequence"
81+ "output" : {
82+ "message" : {
83+ "content" : [{"text" : "Streaming chunk 2" }]
84+ }
85+ }
6186 }).encode ()
6287 }
6388 }
6489 ]
6590 return {'body' : chunks }
6691
67- def run_test ():
68- """Run tests for both standard and streaming completion."""
69- # Initialize Langtrace with API key from environment
70- langtrace .init (api_key = os .environ ["LANGTRACE_API_KEY" ])
71-
72- with patch ("boto3.client" ) as mock_client :
73- mock_instance = mock_client .return_value
74- mock_instance .invoke_model = mock_invoke_model
75- mock_instance .invoke_model_with_response_stream = mock_invoke_model_with_response_stream
76-
77- print ("\n Testing AWS Bedrock instrumentation..." )
78-
79- try :
80- # Test standard completion
81- print ("\n Testing standard completion..." )
82- response = use_converse ("Tell me about OpenTelemetry" )
83- print (f"Standard completion response: { response } " )
84- print ("✓ Standard completion test passed with vendor attributes" )
85-
86- # Test streaming completion
87- print ("\n Testing streaming completion..." )
88- chunks = []
89- for chunk in use_converse_stream ("What is distributed tracing?" ):
90- chunks .append (chunk )
91- print (f"Streaming chunk: { chunk } " )
92- assert len (chunks ) == 2 , f"Expected 2 chunks, got { len (chunks )} "
93- print (f"✓ Streaming completion test passed with { len (chunks )} chunks" )
92+ @pytest .mark .usefixtures ("mock_env" )
93+ class TestAWSBedrock :
94+ """Test suite for AWS Bedrock instrumentation."""
9495
95- print ("\n ✓ All tests completed successfully!" )
96- except AssertionError as e :
97- print (f"\n ❌ Test failed: { str (e )} " )
98- raise
96+ def test_standard_completion (self , mock_bedrock_client ):
97+ """Test standard completion with mocked AWS client."""
98+ response = use_converse ("Tell me about OpenTelemetry" )
99+ assert response is not None
100+ content = response .get ('output' , {}).get ('message' , {}).get ('content' , [])
101+ assert content , "Response content should not be empty"
102+ assert isinstance (content [0 ].get ('text' ), str ), "Response text should be a string"
99103
104+ def test_streaming_completion (self , mock_bedrock_client ):
105+ """Test streaming completion with mocked AWS client."""
106+ chunks = list (use_converse_stream ("What is distributed tracing?" ))
107+ assert len (chunks ) == 2 , f"Expected 2 chunks, got { len (chunks )} "
100108
101- if __name__ == "__main__" :
102- run_test ()
109+ for chunk in chunks :
110+ content = chunk .get ('output' , {}).get ('message' , {}).get ('content' , [])
111+ assert content , "Chunk content should not be empty"
112+ assert isinstance (content [0 ].get ('text' ), str ), "Chunk text should be a string"
0 commit comments