
Commit 9d1288d

add support for invoke_model_with_stream

1 parent 55e822b commit 9d1288d
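
The commit message refers to Bedrock's invoke_model_with_response_stream API. As context for the diffs below, here is a minimal sketch of the call being instrumented, using boto3's bedrock-runtime client; the model ID and prompt mirror the Titan example in this commit and are illustrative, not prescriptive:

import json

import boto3

# Assumes AWS credentials and region are configured in the environment.
brt = boto3.client("bedrock-runtime")

body = json.dumps({"inputText": "What is the capital of France?"})
response = brt.invoke_model_with_response_stream(
    body=body,
    modelId="amazon.titan-text-express-v1",  # illustrative model ID
    accept="application/json",
    contentType="application/json",
)

# The body is an event stream; each event wraps a JSON-encoded chunk.
for event in response["body"]:
    chunk = json.loads(event["chunk"]["bytes"])
    if "outputText" in chunk:
        print(chunk["outputText"], end="")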

File tree: 5 files changed, +181 −37 lines
Lines changed: 8 additions & 9 deletions

@@ -1,9 +1,7 @@
 from examples.awsbedrock_examples.converse import (
-    use_converse_stream,
-    use_converse,
     use_invoke_model_anthropic,
-    use_invoke_model_cohere,
-    use_invoke_model_amazon,
+    use_invoke_model_titan,
+    use_invoke_model_llama,
 )
 from langtrace_python_sdk import langtrace, with_langtrace_root_span
 
@@ -12,8 +10,9 @@ class AWSBedrockRunner:
     @with_langtrace_root_span("AWS_Bedrock")
     def run(self):
 
-        use_converse_stream()
-        use_converse()
-        use_invoke_model_anthropic()
-        use_invoke_model_cohere()
-        use_invoke_model_amazon()
+        # use_converse_stream()
+        # use_converse()
+        # use_invoke_model_anthropic(stream=True)
+        # use_invoke_model_cohere()
+        # use_invoke_model_llama(stream=False)
+        use_invoke_model_titan(stream=False)
src/examples/awsbedrock_examples/converse.py

Lines changed: 20 additions & 6 deletions

@@ -88,6 +88,12 @@ def use_invoke_model_titan(stream=False):
         response = brt.invoke_model_with_response_stream(
             body=body, modelId=modelId, accept=accept, contentType=contentType
         )
+        # Extract and print the response text in real-time.
+        for event in response["body"]:
+            chunk = json.loads(event["chunk"]["bytes"])
+            if "outputText" in chunk:
+                print(chunk["outputText"], end="")
+
     else:
         response = brt.invoke_model(
             body=body, modelId=modelId, accept=accept, contentType=contentType
@@ -130,7 +136,8 @@ def use_invoke_model_anthropic(stream=False):
         for event in stream_response:
             chunk = event.get("chunk")
             if chunk:
-                print(json.loads(chunk.get("bytes").decode()))
+                # print(json.loads(chunk.get("bytes").decode()))
+                pass
 
     else:
         response = brt.invoke_model(
@@ -141,7 +148,7 @@
         print(response_body.get("completion"))
 
 
-def use_invoke_model_llama():
+def use_invoke_model_llama(stream=False):
     model_id = "meta.llama3-8b-instruct-v1:0"
     prompt = "What is the capital of France?"
     max_gen_len = 128
@@ -157,11 +164,18 @@ def use_invoke_model_llama():
             "top_p": top_p,
         }
     )
-    response = brt.invoke_model(body=body, modelId=model_id)
-
-    response_body = json.loads(response.get("body").read())
 
-    return response_body
+    if stream:
+        response = brt.invoke_model_with_response_stream(body=body, modelId=model_id)
+        for event in response["body"]:
+            chunk = json.loads(event["chunk"]["bytes"])
+            if "generation" in chunk:
+                # print(chunk["generation"], end="")
+                pass
+    else:
+        response = brt.invoke_model(body=body, modelId=model_id)
+        response_body = json.loads(response.get("body").read())
+        return response_body
 
 
 # print(get_foundation_models())
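
Note that the examples branch on which key is present in each decoded chunk because vendors stream different payloads. A rough sketch of the shapes, inferred only from the keys this commit reads (values are illustrative, and the schemas are not exhaustive):

# Amazon Titan: text plus completion/token metadata on the final chunk.
titan_chunk = {
    "outputText": "Paris.",
    "completionReason": "FINISH",   # signals the end of the stream
    "inputTextTokenCount": 7,
    "outputTextTokenCount": 42,
}

# Meta Llama: incremental "generation" text; "stop_reason" becomes non-None
# on the final chunk, which also carries invocation metrics.
llama_chunk = {
    "generation": "Paris",
    "stop_reason": "stop",
    "amazon-bedrock-invocationMetrics": {
        "inputTokenCount": 7,
        "outputTokenCount": 42,
    },
}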

src/langtrace_python_sdk/instrumentation/aws_bedrock/patch.py

Lines changed: 150 additions & 19 deletions

@@ -16,9 +16,7 @@
 
 import json
 
-from langtrace_python_sdk.instrumentation.aws_bedrock.bedrock_streaming_wrapper import (
-    StreamingWrapper,
-)
+from wrapt import ObjectProxy
 from .stream_body_wrapper import BufferedStreamBody
 from functools import wraps
 from langtrace.trace_attributes import (
@@ -87,6 +85,11 @@ def traced_method(wrapped, instance, args, kwargs):
 
         client = wrapped(*args, **kwargs)
         client.invoke_model = patch_invoke_model(client.invoke_model, tracer, version)
+        client.invoke_model_with_response_stream = (
+            patch_invoke_model_with_response_stream(
+                client.invoke_model_with_response_stream, tracer, version
+            )
+        )
 
         client.converse = patch_converse(client.converse, tracer, version)
         client.converse_stream = patch_converse_stream(
@@ -186,6 +189,56 @@ def traced_method(*args, **kwargs):
     return traced_method
 
 
+def patch_invoke_model_with_response_stream(original_method, tracer, version):
+    @wraps(original_method)
+    def traced_method(*args, **kwargs):
+        modelId = kwargs.get("modelId")
+        (vendor, _) = modelId.split(".")
+        span_attributes = {
+            **get_langtrace_attributes(version, vendor, vendor_type="framework"),
+            **get_extra_attributes(),
+        }
+        span = tracer.start_span(
+            name=get_span_name("aws_bedrock.invoke_model_with_response_stream"),
+            kind=SpanKind.CLIENT,
+            context=set_span_in_context(trace.get_current_span()),
+        )
+        set_span_attributes(span, span_attributes)
+        response = original_method(*args, **kwargs)
+        if span.is_recording():
+            handle_streaming_call(span, kwargs, response)
+        return response
+
+    return traced_method
+
+
+def handle_streaming_call(span, kwargs, response):
+
+    def stream_finished(response_body):
+        request_body = json.loads(kwargs.get("body"))
+
+        (vendor, model) = kwargs.get("modelId").split(".")
+
+        set_span_attribute(span, SpanAttributes.LLM_REQUEST_MODEL, model)
+        set_span_attribute(span, SpanAttributes.LLM_RESPONSE_MODEL, model)
+
+        if vendor == "amazon":
+            set_amazon_attributes(span, request_body, response_body)
+
+        if vendor == "anthropic":
+            if "prompt" in request_body:
+                set_anthropic_completions_attributes(span, request_body, response_body)
+            elif "messages" in request_body:
+                set_anthropic_messages_attributes(span, request_body, response_body)
+
+        if vendor == "meta":
+            set_llama_meta_attributes(span, request_body, response_body)
+
+        span.end()
+
+    response["body"] = StreamingBedrockWrapper(response["body"], stream_finished)
+
+
 def handle_call(span, kwargs, response):
     modelId = kwargs.get("modelId")
     (vendor, model_name) = modelId.split(".")
@@ -195,7 +248,6 @@ def handle_call(span, kwargs, response):
     request_body = json.loads(kwargs.get("body"))
     response_body = json.loads(response.get("body").read())
 
-    set_span_attribute(span, SpanAttributes.LLM_SYSTEM, vendor)
     set_span_attribute(span, SpanAttributes.LLM_RESPONSE_MODEL, modelId)
     set_span_attribute(span, SpanAttributes.LLM_REQUEST_MODEL, modelId)
 
@@ -222,12 +274,18 @@ def set_llama_meta_attributes(span, request_body, response_body):
     set_span_attribute(
         span, SpanAttributes.LLM_REQUEST_MAX_TOKENS, request_body.get("max_gen_len")
     )
+    if "invocation_metrics" in response_body:
+        input_tokens = response_body.get("invocation_metrics").get("inputTokenCount")
+        output_tokens = response_body.get("invocation_metrics").get("outputTokenCount")
+    else:
+        input_tokens = response_body.get("prompt_token_count")
+        output_tokens = response_body.get("generation_token_count")
 
     set_usage_attributes(
         span,
         {
-            "input_tokens": response_body.get("prompt_token_count"),
-            "output_tokens": response_body.get("generation_token_count"),
+            "input_tokens": input_tokens,
+            "output_tokens": output_tokens,
         },
     )
 
@@ -245,7 +303,6 @@ def set_llama_meta_attributes(span, request_body, response_body):
         }
     ]
     set_span_attribute(span, SpanAttributes.LLM_PROMPTS, json.dumps(prompts))
-    print(completions)
     set_event_completion(span, completions)
 
 
@@ -257,13 +314,22 @@ def set_amazon_attributes(span, request_body, response_body):
             "content": request_body.get("inputText"),
         }
     ]
-    completions = [
-        {
-            "role": "assistant",
-            "content": result.get("outputText"),
-        }
-        for result in response_body.get("results")
-    ]
+    if "results" in response_body:
+        completions = [
+            {
+                "role": "assistant",
+                "content": result.get("outputText"),
+            }
+            for result in response_body.get("results")
+        ]
+
+    else:
+        completions = [
+            {
+                "role": "assistant",
+                "content": response_body.get("outputText"),
+            }
+        ]
     set_span_attribute(
         span, SpanAttributes.LLM_REQUEST_MAX_TOKENS, config.get("maxTokenCount")
     )
@@ -272,13 +338,19 @@ def set_amazon_attributes(span, request_body, response_body):
     )
     set_span_attribute(span, SpanAttributes.LLM_REQUEST_TOP_P, config.get("topP"))
     set_span_attribute(span, SpanAttributes.LLM_PROMPTS, json.dumps(prompts))
+    input_tokens = response_body.get("inputTextTokenCount")
+    if "results" in response_body:
+        output_tokens = sum(
+            int(result.get("tokenCount")) for result in response_body.get("results")
+        )
+    else:
+        output_tokens = response_body.get("outputTextTokenCount")
+
     set_usage_attributes(
         span,
         {
-            "input_tokens": response_body.get("inputTextTokenCount"),
-            "output_tokens": sum(
-                int(result.get("tokenCount")) for result in response_body.get("results")
-            ),
+            "input_tokens": input_tokens,
+            "output_tokens": output_tokens,
         },
     )
     set_event_completion(span, completions)
@@ -320,7 +392,7 @@ def set_anthropic_messages_attributes(span, request_body, response_body):
     set_span_attribute(
         span,
         SpanAttributes.LLM_REQUEST_MAX_TOKENS,
-        request_body.get("max_tokens_to_sample"),
+        request_body.get("max_tokens_to_sample") or request_body.get("max_tokens"),
    )
     set_span_attribute(
         span,
@@ -394,3 +466,62 @@ def set_span_streaming_response(span, response):
     set_event_completion(
         span, [{"role": role or "assistant", "content": streaming_response}]
     )
+
+
+class StreamingBedrockWrapper(ObjectProxy):
+    def __init__(
+        self,
+        response,
+        stream_done_callback=None,
+    ):
+        super().__init__(response)
+
+        self._stream_done_callback = stream_done_callback
+        self._accumulating_body = {"generation": ""}
+
+    def __iter__(self):
+        for event in self.__wrapped__:
+            self._process_event(event)
+            yield event
+
+    def _process_event(self, event):
+        chunk = event.get("chunk")
+        if not chunk:
+            return
+
+        decoded_chunk = json.loads(chunk.get("bytes").decode())
+        type = decoded_chunk.get("type")
+
+        if type is None and "outputText" in decoded_chunk:
+            self._stream_done_callback(decoded_chunk)
+            return
+        if "generation" in decoded_chunk:
+            self._accumulating_body["generation"] += decoded_chunk.get("generation")
+
+        if type == "message_start":
+            self._accumulating_body = decoded_chunk.get("message")
+        elif type == "content_block_start":
+            self._accumulating_body["content"].append(
+                decoded_chunk.get("content_block")
+            )
+        elif type == "content_block_delta":
+            self._accumulating_body["content"][-1]["text"] += decoded_chunk.get(
+                "delta"
+            ).get("text")
+
+        elif self.has_finished(type, decoded_chunk):
+            self._accumulating_body["invocation_metrics"] = decoded_chunk.get(
+                "amazon-bedrock-invocationMetrics"
+            )
+            self._stream_done_callback(self._accumulating_body)
+
+    def has_finished(self, type, chunk):
+        if type and type == "message_stop":
+            return True
+
+        if "completionReason" in chunk and chunk.get("completionReason") == "FINISH":
+            return True
+
+        if "stop_reason" in chunk and chunk.get("stop_reason") is not None:
+            return True
+        return False
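
StreamingBedrockWrapper subclasses wrapt's ObjectProxy so the wrapped body still looks and behaves like botocore's event stream to the caller, while spans are finalized as the stream is consumed. A minimal, self-contained sketch of that technique; ObservedStream and on_done are illustrative names, not part of the SDK:

import wrapt


class ObservedStream(wrapt.ObjectProxy):
    """Proxy an iterable, observing events as the caller consumes them."""

    def __init__(self, stream, on_done=None):
        super().__init__(stream)
        # wrapt keeps attributes prefixed with _self_ on the proxy itself
        # instead of forwarding them to the wrapped object.
        self._self_on_done = on_done
        self._self_events = []

    def __iter__(self):
        # Iteration stays lazy: the caller drives consumption, and every
        # other attribute or method is delegated to the wrapped object.
        for event in self.__wrapped__:
            self._self_events.append(event)
            yield event
        if self._self_on_done is not None:
            self._self_on_done(self._self_events)


# Usage sketch: events pass through untouched, then the callback fires.
stream = ObservedStream(iter(["a", "b"]), on_done=lambda evs: print(len(evs)))
for item in stream:
    print(item)

The SDK's wrapper differs in one respect: rather than waiting for exhaustion, it watches for each vendor's end-of-stream marker (message_stop, completionReason == "FINISH", or a non-None stop_reason) and fires the callback there.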
Lines changed: 1 addition & 1 deletion

@@ -1 +1 @@
-__version__ = "3.3.23"
+__version__ = "3.3.24"

src/run_example.py

Lines changed: 2 additions & 2 deletions

@@ -20,9 +20,9 @@
     "vertexai": False,
     "gemini": False,
     "mistral": False,
-    "awsbedrock": False,
+    "awsbedrock": True,
     "cerebras": False,
-    "google_genai": True,
+    "google_genai": False,
 }
 
 if ENABLED_EXAMPLES["anthropic"]:
