Skip to content

Commit cc8921c

Browse files
committed
add anthropic typings
1 parent 7b95911 commit cc8921c

File tree

5 files changed

+144
-79
lines changed

5 files changed

+144
-79
lines changed

mypy.ini

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
[mypy]
strict = True
disable_error_code = import-untyped
disallow_untyped_calls = True  # Disallow function calls without type annotations
disallow_untyped_defs = True  # Disallow defining functions without type annotations
disallow_any_explicit = True  # Disallow explicit use of `Any`
disallow_any_generics = True  # Disallow generic types without specific type parameters
disallow_incomplete_defs = True  # Disallow defining incomplete function signatures
no_implicit_optional = True  # Disallow implicitly Optional types
warn_unused_configs = True  # Warn about unused configurations
warn_redundant_casts = True  # Warn about unnecessary type casts
warn_return_any = True  # Warn if a function returns `Any`
warn_unreachable = True  # Warn about unreachable code
# Ignore external modules or allow specific imports
follow_imports = skip
ignore_missing_imports = True

src/langtrace_python_sdk/instrumentation/anthropic/instrumentation.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,38 +16,39 @@
1616

1717
import importlib.metadata
1818
import logging
19-
from typing import Collection
19+
from typing import Collection, Any
2020

2121
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
22+
from opentelemetry.trace import TracerProvider
2223
from opentelemetry.trace import get_tracer
2324
from wrapt import wrap_function_wrapper
24-
25-
from langtrace_python_sdk.instrumentation.anthropic.patch import messages_create
25+
from typing import Any
26+
from src.langtrace_python_sdk.instrumentation.anthropic.patch import messages_create
2627

2728
logging.basicConfig(level=logging.FATAL)
2829

2930

30-
class AnthropicInstrumentation(BaseInstrumentor):  # type: ignore[misc]
    """
    The AnthropicInstrumentation class represents the Anthropic instrumentation.

    It wraps ``anthropic.resources.messages.Messages.create`` so calls are
    traced via the configured OpenTelemetry tracer.
    """

    def instrumentation_dependencies(self) -> Collection[str]:
        """Return the package constraints required for this instrumentation."""
        return ["anthropic >= 0.19.1"]

    def _instrument(self, **kwargs: Any) -> None:
        """Patch Messages.create with the tracing wrapper.

        Note: per PEP 484, ``**kwargs`` is annotated with the *value* type
        (``Any``), not ``dict[str, Any]``.
        """
        # BaseInstrumentor.instrument() forwards tracer_provider (may be None,
        # in which case get_tracer falls back to the global provider).
        tracer_provider = kwargs.get("tracer_provider")
        tracer = get_tracer(__name__, "", tracer_provider)
        version = importlib.metadata.version("anthropic")

        wrap_function_wrapper(
            "anthropic.resources.messages",
            "Messages.create",
            messages_create(version, tracer),
        )

    def _instrument_module(self, module_name: str) -> None:
        # No per-module instrumentation needed.
        pass

    def _uninstrument(self, **kwargs: Any) -> None:
        # Nothing to undo; wrapt patches are process-lifetime here.
        pass

src/langtrace_python_sdk/instrumentation/anthropic/patch.py

Lines changed: 79 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -14,51 +14,78 @@
1414
limitations under the License.
1515
"""
1616

17-
import json
18-
19-
from langtrace.trace_attributes import Event, LLMSpanAttributes
20-
from langtrace_python_sdk.utils import set_span_attribute, silently_fail
17+
from typing import Any, Callable, Dict, List, Optional, Iterator, TypedDict, Union
18+
from langtrace.trace_attributes import Event, SpanAttributes, LLMSpanAttributes
2119
from langtrace_python_sdk.utils.llm import (
2220
get_extra_attributes,
2321
get_langtrace_attributes,
2422
get_llm_request_attributes,
2523
get_llm_url,
26-
is_streaming,
2724
set_event_completion,
2825
set_event_completion_chunk,
2926
set_usage_attributes,
27+
set_span_attribute
3028
)
31-
from opentelemetry.trace import SpanKind
32-
from opentelemetry.trace.status import Status, StatusCode
33-
from langtrace.trace_attributes import SpanAttributes
34-
35-
from langtrace_python_sdk.constants.instrumentation.anthropic import APIS
36-
from langtrace_python_sdk.constants.instrumentation.common import (
37-
SERVICE_PROVIDERS,
38-
)
29+
from opentelemetry.trace import Span, Tracer, SpanKind
30+
from opentelemetry.trace.status import StatusCode
31+
from src.langtrace_python_sdk.constants.instrumentation.anthropic import APIS
32+
from src.langtrace_python_sdk.constants.instrumentation.common import SERVICE_PROVIDERS
33+
from src.langtrace_python_sdk.instrumentation.anthropic.types import StreamingResult, ResultType, MessagesCreateKwargs, ContentItem, Usage
34+
35+
def handle_streaming_response(result: StreamingResult, span: Span) -> Iterator[str]:
    """Process a streaming Messages.create result, yielding each text chunk.

    Records stream start/end events, per-chunk completion events, aggregated
    token usage, and the final assembled completion on *span*, then closes it.
    """
    result_content: List[str] = []
    span.add_event(Event.STREAM_START.value)
    input_tokens: int = 0
    output_tokens: int = 0
    try:
        for chunk in result:
            if chunk["message"]["model"] is not None:
                span.set_attribute(
                    SpanAttributes.LLM_RESPONSE_MODEL, chunk["message"]["model"]
                )
            content: str = ""
            if chunk["delta"].get("text") is not None:
                content = chunk["delta"]["text"] or ""
            result_content.append(content)

            if chunk["message"]["usage"] is not None:
                usage = chunk["message"]["usage"]
                input_tokens += usage.get("input_tokens", 0)
                output_tokens += usage.get("output_tokens", 0)

            if content:
                # content is already a str; no join needed.
                set_event_completion_chunk(span, content)

            yield content
    finally:
        # Finalize the span even if the consumer abandons the generator.
        span.add_event(Event.STREAM_END.value)
        set_usage_attributes(
            span, {"input_tokens": input_tokens, "output_tokens": output_tokens}
        )
        completion: List[Dict[str, str]] = [
            {"role": "assistant", "content": "".join(result_content)}
        ]
        set_event_completion(span, completion)

        span.set_status(StatusCode.OK)
        span.end()
4070

41-
def messages_create(original_method, version, tracer):
71+
def messages_create(version: str, tracer: Tracer) -> Callable[..., Any]:
4272
"""Wrap the `messages_create` method."""
4373

44-
def traced_method(wrapped, instance, args, kwargs):
74+
def traced_method(wrapped: Callable[..., Any], instance: Any, args: List[Any], kwargs: MessagesCreateKwargs) -> Any:
4575
service_provider = SERVICE_PROVIDERS["ANTHROPIC"]
4676

47-
# extract system from kwargs and attach as a role to the prompts
48-
# we do this to keep it consistent with the openai
77+
# Extract system from kwargs and attach as a role to the prompts
4978
prompts = kwargs.get("messages", [])
5079
system = kwargs.get("system")
5180
if system:
52-
prompts = [{"role": "system", "content": system}] + kwargs.get(
53-
"messages", []
54-
)
55-
81+
prompts = [{"role": "system", "content": system}] + kwargs.get("messages", [])
82+
extraAttributes = get_extra_attributes()
5683
span_attributes = {
5784
**get_langtrace_attributes(version, service_provider),
5885
**get_llm_request_attributes(kwargs, prompts=prompts),
5986
**get_llm_url(instance),
6087
SpanAttributes.LLM_PATH: APIS["MESSAGES_CREATE"]["ENDPOINT"],
61-
**get_extra_attributes(),
88+
**extraAttributes,
6289
}
6390

6491
attributes = LLMSpanAttributes(**span_attributes)
@@ -77,56 +104,43 @@ def traced_method(wrapped, instance, args, kwargs):
77104
# Record the exception in the span
78105
span.record_exception(err)
79106
# Set the span status to indicate an error
80-
span.set_status(Status(StatusCode.ERROR, str(err)))
107+
span.set_status(StatusCode.ERROR, str(err))
81108
# Reraise the exception to ensure it's not swallowed
82109
span.end()
83110
raise
84111

85-
def handle_streaming_response(result, span):
112+
def handle_streaming_response(result: StreamingResult, span: Span) -> Iterator[str]:
86113
"""Process and yield streaming response chunks."""
87-
result_content = []
114+
result_content: List[str] = []
88115
span.add_event(Event.STREAM_START.value)
89-
input_tokens = 0
90-
output_tokens = 0
116+
input_tokens: int = 0
117+
output_tokens: int = 0
91118
try:
92119
for chunk in result:
93-
if (
94-
hasattr(chunk, "message")
95-
and chunk.message is not None
96-
and hasattr(chunk.message, "model")
97-
and chunk.message.model is not None
98-
):
99-
span.set_attribute(
100-
SpanAttributes.LLM_RESPONSE_MODEL, chunk.message.model
101-
)
102-
content = ""
103-
if hasattr(chunk, "delta") and chunk.delta is not None:
104-
content = chunk.delta.text if hasattr(chunk.delta, "text") else ""
105-
# Assuming content needs to be aggregated before processing
120+
span.set_attribute(
121+
SpanAttributes.LLM_RESPONSE_MODEL, chunk["message"]["model"] or ""
122+
)
123+
content: str = ""
124+
if hasattr(chunk, "delta") and chunk["delta"] is not None:
125+
content = chunk["delta"]["text"] or ""
106126
result_content.append(content if len(content) > 0 else "")
107-
108-
if hasattr(chunk, "message") and hasattr(chunk.message, "usage"):
127+
if chunk["message"]["usage"] is not None:
109128
input_tokens += (
110-
chunk.message.usage.input_tokens
111-
if hasattr(chunk.message.usage, "input_tokens")
129+
chunk["message"]["usage"]["input_tokens"]
130+
if hasattr(chunk["message"]["usage"], "input_tokens")
112131
else 0
113132
)
114133
output_tokens += (
115-
chunk.message.usage.output_tokens
116-
if hasattr(chunk.message.usage, "output_tokens")
134+
chunk["message"]["usage"]["output_tokens"]
135+
if hasattr(chunk["message"]["usage"], "output_tokens")
117136
else 0
118137
)
119138

120-
# Assuming span.add_event is part of a larger logging or event system
121-
# Add event for each chunk of content
122139
if content:
123140
set_event_completion_chunk(span, "".join(content))
124141

125-
# Assuming this is part of a generator, yield chunk or aggregated content
126142
yield content
127143
finally:
128-
129-
# Finalize span after processing all chunks
130144
span.add_event(Event.STREAM_END.value)
131145
set_usage_attributes(
132146
span, {"input_tokens": input_tokens, "output_tokens": output_tokens}
@@ -137,36 +151,34 @@ def handle_streaming_response(result, span):
137151
span.set_status(StatusCode.OK)
138152
span.end()
139153

140-
def set_response_attributes(result, span, kwargs):
141-
if not is_streaming(kwargs):
142-
if hasattr(result, "content") and result.content is not None:
154+
def set_response_attributes(result: Union[ResultType, StreamingResult], span: Span, kwargs: MessagesCreateKwargs) -> Any:
155+
if not isinstance(result, Iterator):
156+
if result["content"] is not None:
143157
set_span_attribute(
144-
span, SpanAttributes.LLM_RESPONSE_MODEL, result.model
158+
span, SpanAttributes.LLM_RESPONSE_MODEL, result["model"]
145159
)
146-
completion = [
160+
content_item = result["content"][0]
161+
completion: List[ContentItem] = [
147162
{
148-
"role": result.role if result.role else "assistant",
149-
"content": result.content[0].text,
150-
"type": result.content[0].type,
163+
"role": result["role"] or "assistant",
164+
"content": content_item.get("text", ""),
165+
"type": content_item.get("type", ""),
151166
}
152167
]
153168
set_event_completion(span, completion)
154169

155170
else:
156-
responses = []
171+
responses: List[ContentItem] = []
157172
set_event_completion(span, responses)
158173

159-
if (
160-
hasattr(result, "system_fingerprint")
161-
and result.system_fingerprint is not None
162-
):
174+
if result["system_fingerprint"] is not None:
163175
span.set_attribute(
164176
SpanAttributes.LLM_SYSTEM_FINGERPRINT,
165-
result.system_fingerprint,
177+
result["system_fingerprint"],
166178
)
167179
# Get the usage
168-
if hasattr(result, "usage") and result.usage is not None:
169-
usage = result.usage
180+
if result["usage"] is not None:
181+
usage: Usage = result["usage"]
170182
set_usage_attributes(span, dict(usage))
171183

172184
span.set_status(StatusCode.OK)
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
"""Typed structures describing Anthropic Messages.create payloads."""

from typing import Dict, Iterator, List, Optional, TypedDict


class MessagesCreateKwargs(TypedDict, total=False):
    """Keyword arguments accepted by ``Messages.create``."""

    system: str
    messages: List[Dict[str, str]]


class Usage(TypedDict, total=True):
    """Token accounting reported by the Anthropic API."""

    input_tokens: int
    output_tokens: int


class Message(TypedDict, total=True):
    """Message metadata carried by a streaming chunk."""

    model: Optional[str]
    usage: Optional[Usage]


class Delta(TypedDict, total=True):
    """Incremental text payload of a streaming chunk."""

    text: Optional[str]


class Chunk(TypedDict, total=True):
    """One streamed event: message metadata plus the text delta."""

    message: Message
    delta: Delta


class ContentItem(TypedDict, total=False):
    """A single content entry of a non-streaming response."""

    role: str
    content: str
    text: str
    type: str


class ResultType(TypedDict, total=True):
    """Shape of a complete (non-streaming) Messages.create result."""

    model: Optional[str]
    role: Optional[str]
    content: List[ContentItem]
    system_fingerprint: Optional[str]
    usage: Optional[Usage]


# The streaming result is an iterator that yields Chunk objects.
StreamingResult = Iterator[Chunk]

src/langtrace_python_sdk/utils/llm.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
limitations under the License.
1515
"""
1616

17+
from typing import Any, Dict
1718
from langtrace_python_sdk.constants import LANGTRACE_SDK_NAME
1819
from langtrace_python_sdk.utils import set_span_attribute
1920
from openai import NOT_GIVEN
@@ -136,9 +137,9 @@ def get_llm_request_attributes(kwargs, prompts=None, model=None, operation_name=
136137
}
137138

138139

139-
def get_extra_attributes() -> Dict[str, Any]:
    """Return extra span attributes stored in OTel baggage, or an empty dict.

    ``baggage.get_baggage`` returns the stored value or None; accessing
    ``.__dict__`` on it raises AttributeError for both a dict (no __dict__)
    and None, so narrow with isinstance instead.
    """
    extra_attributes = baggage.get_baggage(LANGTRACE_ADDITIONAL_SPAN_ATTRIBUTES_KEY)
    if isinstance(extra_attributes, dict):
        return extra_attributes
    return {}
142143

143144

144145
def get_llm_url(instance):

0 commit comments

Comments
 (0)