diff --git a/agentops/instrumentation/anthropic/attributes/message.py b/agentops/instrumentation/anthropic/attributes/message.py index 624fe04be..c99885b02 100644 --- a/agentops/instrumentation/anthropic/attributes/message.py +++ b/agentops/instrumentation/anthropic/attributes/message.py @@ -60,12 +60,16 @@ def get_message_attributes( MessageStopEvent, MessageStreamEvent, ) + from anthropic import Stream if isinstance(return_value, Message): attributes.update(get_message_response_attributes(return_value)) if hasattr(return_value, "content"): attributes.update(get_tool_attributes(return_value.content)) + elif isinstance(return_value, Stream): + for event in return_value: + attributes.update(get_stream_event_attributes(event)) elif isinstance(return_value, MessageStreamEvent): attributes.update(get_stream_attributes(return_value)) elif isinstance( @@ -511,4 +515,27 @@ def get_stream_event_attributes(event: Any) -> AttributeMap: attributes[SpanAttributes.LLM_RESPONSE_FINISH_REASON] = stop_reason attributes[MessageAttributes.COMPLETION_FINISH_REASON.format(i=0)] = stop_reason + elif event_type == "RawMessageStartEvent": + if hasattr(event, "message"): + if hasattr(event.message, "usage"): + usage = event.message.usage + if hasattr(usage, "input_tokens"): + input_tokens = usage.input_tokens + attributes[SpanAttributes.LLM_USAGE_PROMPT_TOKENS] = input_tokens + + if hasattr(usage, "output_tokens"): + output_tokens = usage.output_tokens + attributes[SpanAttributes.LLM_USAGE_COMPLETION_TOKENS] = output_tokens + + if hasattr(usage, "input_tokens") and hasattr(usage, "output_tokens"): + total_tokens = usage.input_tokens + usage.output_tokens + attributes[SpanAttributes.LLM_USAGE_TOTAL_TOKENS] = total_tokens + + elif event_type == "RawMessageDeltaEvent": + if hasattr(event, "delta"): + if hasattr(event.delta, "stop_reason"): + stop_reason = event.delta.stop_reason + attributes[SpanAttributes.LLM_RESPONSE_STOP_REASON] = stop_reason + attributes[SpanAttributes.LLM_RESPONSE_FINISH_REASON] = stop_reason + attributes[MessageAttributes.COMPLETION_FINISH_REASON.format(i=0)] = stop_reason return attributes diff --git a/tests/unit/instrumentation/anthropic/test_attributes.py b/tests/unit/instrumentation/anthropic/test_attributes.py index 505499eb4..6354af2ea 100644 --- a/tests/unit/instrumentation/anthropic/test_attributes.py +++ b/tests/unit/instrumentation/anthropic/test_attributes.py @@ -86,20 +86,64 @@ def __init__(self): assert attributes[SpanAttributes.LLM_REQUEST_MODEL] == "claude-3-opus-20240229" -def test_get_stream_event_attributes_start(mock_stream_event): - """Test extraction of stream start event attributes.""" - attributes = get_stream_event_attributes(mock_stream_event) - assert attributes[SpanAttributes.LLM_RESPONSE_ID] == "msg_123" - assert attributes[SpanAttributes.LLM_RESPONSE_MODEL] == "claude-3-opus-20240229" - assert attributes[MessageAttributes.COMPLETION_ID.format(i=0)] == "msg_123" +def test_get_stream_event_attributes_sequence(mock_stream_event, mock_message_stop_event): + """Test extraction of attributes from a sequence of stream events.""" + # Test MessageStartEvent + start_attributes = get_stream_event_attributes(mock_stream_event) + assert start_attributes[SpanAttributes.LLM_RESPONSE_ID] == "msg_123" + assert start_attributes[SpanAttributes.LLM_RESPONSE_MODEL] == "claude-3-opus-20240229" + assert start_attributes[MessageAttributes.COMPLETION_ID.format(i=0)] == "msg_123" + # Test MessageStopEvent + stop_attributes = get_stream_event_attributes(mock_message_stop_event) + assert stop_attributes[SpanAttributes.LLM_RESPONSE_STOP_REASON] == "stop_sequence" + assert stop_attributes[SpanAttributes.LLM_RESPONSE_FINISH_REASON] == "stop_sequence" + assert stop_attributes[MessageAttributes.COMPLETION_FINISH_REASON.format(i=0)] == "stop_sequence" -def test_get_stream_event_attributes_stop(mock_message_stop_event): - """Test extraction of stream stop event attributes.""" - attributes = get_stream_event_attributes(mock_message_stop_event) - assert attributes[SpanAttributes.LLM_RESPONSE_STOP_REASON] == "stop_sequence" - assert attributes[SpanAttributes.LLM_RESPONSE_FINISH_REASON] == "stop_sequence" - assert attributes[MessageAttributes.COMPLETION_FINISH_REASON.format(i=0)] == "stop_sequence" + +def test_get_stream_event_attributes_raw_message_start(): + """Test extraction of raw message start event attributes.""" + + class MockUsage: + def __init__(self): + self.input_tokens = 10 + self.output_tokens = 5 + + class MockMessage: + def __init__(self): + self.usage = MockUsage() + + class MockRawMessageStartEvent: + def __init__(self): + self.message = MockMessage() + self.__class__.__name__ = "RawMessageStartEvent" + + event = MockRawMessageStartEvent() + attributes = get_stream_event_attributes(event) + + assert attributes[SpanAttributes.LLM_USAGE_PROMPT_TOKENS] == 10 + assert attributes[SpanAttributes.LLM_USAGE_COMPLETION_TOKENS] == 5 + assert attributes[SpanAttributes.LLM_USAGE_TOTAL_TOKENS] == 15 + + +def test_get_stream_event_attributes_raw_message_delta(): + """Test extraction of raw message delta event attributes.""" + + class MockDelta: + def __init__(self): + self.stop_reason = "end_turn" + + class MockRawMessageDeltaEvent: + def __init__(self): + self.delta = MockDelta() + self.__class__.__name__ = "RawMessageDeltaEvent" + + event = MockRawMessageDeltaEvent() + attributes = get_stream_event_attributes(event) + + assert attributes[SpanAttributes.LLM_RESPONSE_STOP_REASON] == "end_turn" + assert attributes[SpanAttributes.LLM_RESPONSE_FINISH_REASON] == "end_turn" + assert attributes[MessageAttributes.COMPLETION_FINISH_REASON.format(i=0)] == "end_turn" # Tool Attributes Tests @@ -166,3 +210,19 @@ def __init__(self): attributes = get_tool_attributes(content) assert MessageAttributes.COMPLETION_TOOL_CALL_NAME.format(i=0, j=0) in attributes assert attributes[MessageAttributes.COMPLETION_TOOL_CALL_NAME.format(i=0, j=0)] == "calculator" + + +def test_get_message_attributes_with_stream(mock_stream_event, mock_message_stop_event): + """Test extraction of attributes from a Stream object.""" + + # Test MessageStartEvent attributes + start_attributes = get_stream_event_attributes(mock_stream_event) + assert start_attributes[SpanAttributes.LLM_RESPONSE_ID] == "msg_123" + assert start_attributes[SpanAttributes.LLM_RESPONSE_MODEL] == "claude-3-opus-20240229" + assert start_attributes[MessageAttributes.COMPLETION_ID.format(i=0)] == "msg_123" + + # Test MessageStopEvent attributes + stop_attributes = get_stream_event_attributes(mock_message_stop_event) + assert stop_attributes[SpanAttributes.LLM_RESPONSE_STOP_REASON] == "stop_sequence" + assert stop_attributes[SpanAttributes.LLM_RESPONSE_FINISH_REASON] == "stop_sequence" + assert stop_attributes[MessageAttributes.COMPLETION_FINISH_REASON.format(i=0)] == "stop_sequence"