@@ -48,30 +48,42 @@ def wrapper(original_method):
4848 @wraps (original_method )
4949 def wrapped_method (* args , ** kwargs ):
5050 service_provider = SERVICE_PROVIDERS ["AWS_BEDROCK" ]
51-
51+ print ( "Here's the kwargs: " , kwargs )
5252 input_content = [
5353 {
54- ' role' : message .get (' role' , ' user' ),
55- ' content' : message .get (' content' , [])[0 ].get (' text' , "" )
54+ " role" : message .get (" role" , " user" ),
55+ " content" : message .get (" content" , [])[0 ].get (" text" , "" ),
5656 }
57- for message in kwargs .get (' messages' , [])
57+ for message in kwargs .get (" messages" , [])
5858 ]
59-
59+
6060 span_attributes = {
61- ** get_langtrace_attributes (version , service_provider , vendor_type = "framework" ),
62- ** get_llm_request_attributes (kwargs , operation_name = operation_name , prompts = input_content ),
61+ ** get_langtrace_attributes (
62+ version , service_provider , vendor_type = "framework"
63+ ),
64+ ** get_llm_request_attributes (
65+ kwargs , operation_name = operation_name , prompts = input_content
66+ ),
6367 ** get_llm_url (args [0 ] if args else None ),
6468 SpanAttributes .LLM_PATH : APIS [api_name ]["ENDPOINT" ],
6569 ** get_extra_attributes (),
6670 }
6771
6872 if api_name == "CONVERSE" :
69- span_attributes .update ({
70- SpanAttributes .LLM_REQUEST_MODEL : kwargs .get ('modelId' ),
71- SpanAttributes .LLM_REQUEST_MAX_TOKENS : kwargs .get ('inferenceConfig' , {}).get ('maxTokens' ),
72- SpanAttributes .LLM_REQUEST_TEMPERATURE : kwargs .get ('inferenceConfig' , {}).get ('temperature' ),
73- SpanAttributes .LLM_REQUEST_TOP_P : kwargs .get ('inferenceConfig' , {}).get ('top_p' ),
74- })
73+ span_attributes .update (
74+ {
75+ SpanAttributes .LLM_REQUEST_MODEL : kwargs .get ("modelId" ),
76+ SpanAttributes .LLM_REQUEST_MAX_TOKENS : kwargs .get (
77+ "inferenceConfig" , {}
78+ ).get ("maxTokens" ),
79+ SpanAttributes .LLM_REQUEST_TEMPERATURE : kwargs .get (
80+ "inferenceConfig" , {}
81+ ).get ("temperature" ),
82+ SpanAttributes .LLM_REQUEST_TOP_P : kwargs .get (
83+ "inferenceConfig" , {}
84+ ).get ("top_p" ),
85+ }
86+ )
7587
7688 attributes = LLMSpanAttributes (** span_attributes )
7789
@@ -92,20 +104,22 @@ def wrapped_method(*args, **kwargs):
92104 raise err
93105
94106 return wrapped_method
107+
95108 return wrapper
109+
96110 return decorator
97111
98112
99113converse = traced_aws_bedrock_call ("CONVERSE" , "converse" )
114+ invoke_model = traced_aws_bedrock_call ("INVOKE_MODEL" , "invoke_model" )
100115
101116
102117def converse_stream (original_method , version , tracer ):
103118 def traced_method (wrapped , instance , args , kwargs ):
104119 service_provider = SERVICE_PROVIDERS ["AWS_BEDROCK" ]
105-
120+
106121 span_attributes = {
107- ** get_langtrace_attributes
108- (version , service_provider , vendor_type = "llm" ),
122+ ** get_langtrace_attributes (version , service_provider , vendor_type = "llm" ),
109123 ** get_llm_request_attributes (kwargs ),
110124 ** get_llm_url (instance ),
111125 SpanAttributes .LLM_PATH : APIS ["CONVERSE_STREAM" ]["ENDPOINT" ],
@@ -129,29 +143,87 @@ def traced_method(wrapped, instance, args, kwargs):
129143 span .record_exception (err )
130144 span .set_status (Status (StatusCode .ERROR , str (err )))
131145 raise err
132-
146+
133147 return traced_method
134148
135149
136150@silently_fail
137151def _set_response_attributes (span , kwargs , result ):
138- set_span_attribute (span , SpanAttributes .LLM_RESPONSE_MODEL , kwargs .get ('modelId' ))
139- set_span_attribute (span , SpanAttributes .LLM_TOP_K , kwargs .get ('additionalModelRequestFields' , {}).get ('top_k' ))
140- content = result .get ('output' , {}).get ('message' , {}).get ('content' , [])
152+ set_span_attribute (span , SpanAttributes .LLM_RESPONSE_MODEL , kwargs .get ("modelId" ))
153+ set_span_attribute (
154+ span ,
155+ SpanAttributes .LLM_TOP_K ,
156+ kwargs .get ("additionalModelRequestFields" , {}).get ("top_k" ),
157+ )
158+ content = result .get ("output" , {}).get ("message" , {}).get ("content" , [])
141159 if len (content ) > 0 :
142- role = result .get ('output' , {}).get ('message' , {}).get ('role' , "assistant" )
143- responses = [
144- {"role" : role , "content" : c .get ('text' , "" )}
145- for c in content
146- ]
160+ role = result .get ("output" , {}).get ("message" , {}).get ("role" , "assistant" )
161+ responses = [{"role" : role , "content" : c .get ("text" , "" )} for c in content ]
147162 set_event_completion (span , responses )
148163
149- if ' usage' in result :
164+ if " usage" in result :
150165 set_span_attributes (
151166 span ,
152167 {
153- SpanAttributes .LLM_USAGE_COMPLETION_TOKENS : result ['usage' ].get ('outputTokens' ),
154- SpanAttributes .LLM_USAGE_PROMPT_TOKENS : result ['usage' ].get ('inputTokens' ),
155- SpanAttributes .LLM_USAGE_TOTAL_TOKENS : result ['usage' ].get ('totalTokens' ),
156- }
168+ SpanAttributes .LLM_USAGE_COMPLETION_TOKENS : result ["usage" ].get (
169+ "outputTokens"
170+ ),
171+ SpanAttributes .LLM_USAGE_PROMPT_TOKENS : result ["usage" ].get (
172+ "inputTokens"
173+ ),
174+ SpanAttributes .LLM_USAGE_TOTAL_TOKENS : result ["usage" ].get (
175+ "totalTokens"
176+ ),
177+ },
178+ )
179+
180+
181+ def patch_aws_bedrock (tracer , version ):
182+ def traced_method (wrapped , instance , args , kwargs ):
183+ if args and args [0 ] != "bedrock-runtime" :
184+ return
185+
186+ client = wrapped (* args , ** kwargs )
187+ print ("Here's the client: " , client )
188+ client .invoke_model = patch_invoke_model (client .invoke_model , tracer , version )
189+ client .invoke_model_with_response_stream = patch_invoke_model (
190+ client .invoke_model_with_response_stream , tracer , version
191+ )
192+ client .converse = patch_invoke_model (client .converse , tracer , version )
193+ client .converse_stream = patch_invoke_model (
194+ client .converse_stream , tracer , version
157195 )
196+ return client
197+
198+ return traced_method
199+
200+
201+ def patch_invoke_model (original_method , tracer , version ):
202+ def traced_method (* args , ** kwargs ):
203+ service_provider = SERVICE_PROVIDERS ["AWS_BEDROCK" ]
204+ span_attributes = {
205+ ** get_langtrace_attributes (
206+ version , service_provider , vendor_type = "framework"
207+ ),
208+ ** get_extra_attributes (),
209+ }
210+ with tracer .start_as_current_span (
211+ name = get_span_name ("aws_bedrock.invoke_model" ),
212+ kind = SpanKind .CLIENT ,
213+ context = set_span_in_context (trace .get_current_span ()),
214+ ) as span :
215+ set_span_attributes (span , span_attributes )
216+ set_invoke_model_attributes (span , kwargs )
217+ response = original_method (* args , ** kwargs )
218+ return response
219+
220+ return traced_method
221+
222+
223+ def set_invoke_model_attributes (span , kwargs ):
224+ modelId = kwargs .get ("modelId" )
225+ (vendor , model_name ) = modelId .split ("." )
226+
227+ print ("Here's the vendor: " , vendor )
228+ print ("Here's the model_name: " , model_name )
229+ print ("Here's the kwargs: " , kwargs )
0 commit comments