@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 import abc
 import inspect
+import io
 import json
 import math
 from typing import Any, Dict, Optional
@@ -320,16 +321,20 @@ def _set_if_not_none(attributes, key, value):
 
     def on_success(self, span: Span, result: Dict[str, Any]):
         model_id = self._call_context.params.get(_MODEL_ID)
 
         if not model_id:
             return
 
         if "body" in result and isinstance(result["body"], StreamingBody):
+            original_body = None
             try:
-                # Read the entire content of the StreamingBody
-                body_content = result["body"].read()
-                # Decode the bytes to string and parse as JSON
-                response_body = json.loads(body_content.decode("utf-8"))
+                original_body = result["body"]
+                body_content = original_body.read()
+
+                # Use one stream for telemetry
+                stream = io.BytesIO(body_content)
+                telemetry_content = stream.read()
+                response_body = json.loads(telemetry_content.decode("utf-8"))
                 if "amazon.titan" in model_id:
                     self._handle_amazon_titan_response(span, response_body)
                 elif "anthropic.claude" in model_id:
@@ -342,21 +347,24 @@ def on_success(self, span: Span, result: Dict[str, Any]):
                     self._handle_ai21_jamba_response(span, response_body)
                 elif "mistral" in model_id:
                     self._handle_mistral_mistral_response(span, response_body)
+                # Replenish stream for downstream application use
+                new_stream = io.BytesIO(body_content)
+                result["body"] = StreamingBody(new_stream, len(body_content))
 
             except json.JSONDecodeError:
                 print("Error: Unable to parse the response body as JSON")
             except Exception as e:  # pylint: disable=broad-exception-caught, invalid-name
                 print(f"Error processing response: {str(e)}")
             finally:
-                # Make sure to close the stream
-                result["body"].close()
+                if original_body is not None:
+                    original_body.close()
 
     # pylint: disable=no-self-use
     def _handle_amazon_titan_response(self, span: Span, response_body: Dict[str, Any]):
         if "inputTextTokenCount" in response_body:
             span.set_attribute(GEN_AI_USAGE_INPUT_TOKENS, response_body["inputTextTokenCount"])
 
-        result = response_body["results"][0]
+        if "results" in response_body:
+            result = response_body["results"][0]
             if "tokenCount" in result:
                 span.set_attribute(GEN_AI_USAGE_OUTPUT_TOKENS, result["tokenCount"])
             if "completionReason" in result: