Skip to content

Commit 31cef4b

Browse files
authored
Merge pull request #326 from Scale3-Labs/release-2.3.4
Release 2.3.4: Add Types for OpenAI & Anthropic Instrumentations + Refactoring accessing attributes
2 parents a9f0083 + 77a1dd1 commit 31cef4b

File tree

20 files changed

+513
-234
lines changed

20 files changed

+513
-234
lines changed

.vscode/settings.json

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1-
21
{
32
"[python]": {
43
"editor.defaultFormatter": "ms-python.black-formatter",
54
},
65
"editor.formatOnSave": true,
6+
"python.testing.pytestArgs": [
7+
"src"
8+
],
9+
"python.testing.unittestEnabled": false,
10+
"python.testing.pytestEnabled": true,
711
}

mypy.ini

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
[mypy]
2+
strict = True
3+
disable_error_code = import-untyped
4+
disallow_untyped_calls = True # Disallow function calls without type annotations
5+
disallow_untyped_defs = True # Disallow defining functions without type annotations
6+
disallow_any_explicit = True # Disallow explicit use of `Any`
7+
disallow_any_generics = True # Disallow generic types without specific type parameters
8+
disallow_incomplete_defs = True # Disallow defining incomplete function signatures
9+
no_implicit_optional = True # Disallow implicitly Optional types
10+
warn_unused_configs = True # Warn about unused configurations
11+
warn_redundant_casts = True # Warn about unnecessary type casts
12+
warn_return_any = True # Warn if a function returns `Any`
13+
warn_unreachable = True # Warn about unreachable code
14+
# Ignore external modules or allow specific imports
15+
follow_imports = skip
16+
ignore_missing_imports = True

src/examples/anthropic_example/completion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
_ = load_dotenv(find_dotenv())
99

10-
langtrace.init()
10+
langtrace.init(write_spans_to_console=True)
1111

1212

1313
@with_langtrace_root_span("messages_create")

src/examples/openai_example/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@ def run(self):
2020
chat_completion_example()
2121
embeddings_create_example()
2222
function_example()
23+
image_edit()

src/langtrace_python_sdk/instrumentation/anthropic/instrumentation.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,38 +16,39 @@
1616

1717
import importlib.metadata
1818
import logging
19-
from typing import Collection
19+
from typing import Collection, Any
2020

2121
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
22+
from opentelemetry.trace import TracerProvider
2223
from opentelemetry.trace import get_tracer
2324
from wrapt import wrap_function_wrapper
24-
25+
from typing import Any
2526
from langtrace_python_sdk.instrumentation.anthropic.patch import messages_create
2627

2728
logging.basicConfig(level=logging.FATAL)
2829

2930

30-
class AnthropicInstrumentation(BaseInstrumentor):
31+
class AnthropicInstrumentation(BaseInstrumentor): # type: ignore[misc]
3132
"""
32-
The AnthropicInstrumentation class represents the Anthropic instrumentation
33+
The AnthropicInstrumentation class represents the Anthropic instrumentation.
3334
"""
3435

3536
def instrumentation_dependencies(self) -> Collection[str]:
3637
return ["anthropic >= 0.19.1"]
3738

38-
def _instrument(self, **kwargs):
39-
tracer_provider = kwargs.get("tracer_provider")
39+
def _instrument(self, **kwargs: dict[str, Any]) -> None:
40+
tracer_provider: TracerProvider = kwargs.get("tracer_provider") # type: ignore
4041
tracer = get_tracer(__name__, "", tracer_provider)
4142
version = importlib.metadata.version("anthropic")
4243

4344
wrap_function_wrapper(
4445
"anthropic.resources.messages",
4546
"Messages.create",
46-
messages_create("anthropic.messages.create", version, tracer),
47+
messages_create(version, tracer),
4748
)
4849

49-
def _instrument_module(self, module_name):
50+
def _instrument_module(self, module_name: str) -> None:
5051
pass
5152

52-
def _uninstrument(self, **kwargs):
53+
def _uninstrument(self, **kwargs: dict[str, Any]) -> None:
5354
pass

src/langtrace_python_sdk/instrumentation/anthropic/patch.py

Lines changed: 33 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,8 @@
1414
limitations under the License.
1515
"""
1616

17-
import json
18-
19-
from langtrace.trace_attributes import Event, LLMSpanAttributes
17+
from typing import Any, Callable, Dict, List, Optional, Iterator, TypedDict, Union
18+
from langtrace.trace_attributes import Event, SpanAttributes, LLMSpanAttributes
2019
from langtrace_python_sdk.utils import set_span_attribute
2120
from langtrace_python_sdk.utils.silently_fail import silently_fail
2221

@@ -27,41 +26,48 @@
2726
get_llm_request_attributes,
2827
get_llm_url,
2928
get_span_name,
30-
is_streaming,
3129
set_event_completion,
3230
set_usage_attributes,
31+
set_span_attribute,
3332
)
34-
from opentelemetry.trace import SpanKind
35-
from opentelemetry.trace.status import Status, StatusCode
36-
from langtrace.trace_attributes import SpanAttributes
37-
33+
from opentelemetry.trace import Span, Tracer, SpanKind
34+
from opentelemetry.trace.status import StatusCode
3835
from langtrace_python_sdk.constants.instrumentation.anthropic import APIS
39-
from langtrace_python_sdk.constants.instrumentation.common import (
40-
SERVICE_PROVIDERS,
36+
from langtrace_python_sdk.constants.instrumentation.common import SERVICE_PROVIDERS
37+
from langtrace_python_sdk.instrumentation.anthropic.types import (
38+
StreamingResult,
39+
ResultType,
40+
MessagesCreateKwargs,
41+
ContentItem,
42+
Usage,
4143
)
4244

4345

44-
def messages_create(original_method, version, tracer):
46+
def messages_create(version: str, tracer: Tracer) -> Callable[..., Any]:
4547
"""Wrap the `messages_create` method."""
4648

47-
def traced_method(wrapped, instance, args, kwargs):
49+
def traced_method(
50+
wrapped: Callable[..., Any],
51+
instance: Any,
52+
args: List[Any],
53+
kwargs: MessagesCreateKwargs,
54+
) -> Any:
4855
service_provider = SERVICE_PROVIDERS["ANTHROPIC"]
4956

50-
# extract system from kwargs and attach as a role to the prompts
51-
# we do this to keep it consistent with the openai
57+
# Extract system from kwargs and attach as a role to the prompts
5258
prompts = kwargs.get("messages", [])
5359
system = kwargs.get("system")
5460
if system:
5561
prompts = [{"role": "system", "content": system}] + kwargs.get(
5662
"messages", []
5763
)
58-
64+
extraAttributes = get_extra_attributes()
5965
span_attributes = {
6066
**get_langtrace_attributes(version, service_provider),
6167
**get_llm_request_attributes(kwargs, prompts=prompts),
6268
**get_llm_url(instance),
6369
SpanAttributes.LLM_PATH: APIS["MESSAGES_CREATE"]["ENDPOINT"],
64-
**get_extra_attributes(),
70+
**extraAttributes, # type: ignore
6571
}
6672

6773
attributes = LLMSpanAttributes(**span_attributes)
@@ -74,37 +80,35 @@ def traced_method(wrapped, instance, args, kwargs):
7480
try:
7581
# Attempt to call the original method
7682
result = wrapped(*args, **kwargs)
77-
return set_response_attributes(result, span, kwargs)
83+
return set_response_attributes(result, span)
7884

7985
except Exception as err:
8086
# Record the exception in the span
8187
span.record_exception(err)
8288
# Set the span status to indicate an error
83-
span.set_status(Status(StatusCode.ERROR, str(err)))
89+
span.set_status(StatusCode.ERROR, str(err))
8490
# Reraise the exception to ensure it's not swallowed
8591
span.end()
8692
raise
8793

88-
@silently_fail
89-
def set_response_attributes(result, span, kwargs):
90-
if not is_streaming(kwargs):
94+
def set_response_attributes(
95+
result: Union[ResultType, StreamingResult], span: Span
96+
) -> Any:
97+
if not isinstance(result, Iterator):
9198
if hasattr(result, "content") and result.content is not None:
9299
set_span_attribute(
93100
span, SpanAttributes.LLM_RESPONSE_MODEL, result.model
94101
)
102+
content_item = result.content[0]
95103
completion = [
96104
{
97-
"role": result.role if result.role else "assistant",
98-
"content": result.content[0].text,
99-
"type": result.content[0].type,
105+
"role": result.role or "assistant",
106+
"content": content_item.text,
107+
"type": content_item.type,
100108
}
101109
]
102110
set_event_completion(span, completion)
103111

104-
else:
105-
responses = []
106-
set_event_completion(span, responses)
107-
108112
if (
109113
hasattr(result, "system_fingerprint")
110114
and result.system_fingerprint is not None
@@ -116,7 +120,7 @@ def set_response_attributes(result, span, kwargs):
116120
# Get the usage
117121
if hasattr(result, "usage") and result.usage is not None:
118122
usage = result.usage
119-
set_usage_attributes(span, dict(usage))
123+
set_usage_attributes(span, vars(usage))
120124

121125
span.set_status(StatusCode.OK)
122126
span.end()
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
"""
2+
Copyright (c) 2024 Scale3 Labs
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
http://www.apache.org/licenses/LICENSE-2.0
7+
Unless required by applicable law or agreed to in writing, software
8+
distributed under the License is distributed on an "AS IS" BASIS,
9+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10+
See the License for the specific language governing permissions and
11+
limitations under the License.
12+
"""
13+
14+
from typing import Dict, List, Optional, Iterator, TypedDict
15+
16+
17+
class MessagesCreateKwargs(TypedDict, total=False):
18+
system: Optional[str]
19+
messages: List[Dict[str, str]]
20+
21+
22+
class Usage:
23+
input_tokens: int
24+
output_tokens: int
25+
26+
def __init__(self, input_tokens: int, output_tokens: int):
27+
self.input_tokens = input_tokens
28+
self.output_tokens = output_tokens
29+
30+
31+
class Message:
32+
def __init__(
33+
self,
34+
id: str,
35+
model: Optional[str],
36+
usage: Optional[Usage],
37+
):
38+
self.id = id
39+
self.model = model
40+
self.usage = usage
41+
42+
model: Optional[str]
43+
usage: Optional[Usage]
44+
45+
46+
class Delta:
47+
text: Optional[str]
48+
49+
def __init__(
50+
self,
51+
text: Optional[str],
52+
):
53+
self.text = text
54+
55+
56+
class Chunk:
57+
message: Message
58+
delta: Delta
59+
60+
def __init__(
61+
self,
62+
message: Message,
63+
delta: Delta,
64+
):
65+
self.message = message
66+
self.delta = delta
67+
68+
69+
class ContentItem:
70+
role: str
71+
content: str
72+
text: str
73+
type: str
74+
75+
def __init__(self, role: str, content: str, text: str, type: str):
76+
self.role = role
77+
self.content = content
78+
self.text = text
79+
self.type = type
80+
81+
82+
class ResultType:
83+
model: Optional[str]
84+
role: Optional[str]
85+
content: List[ContentItem]
86+
system_fingerprint: Optional[str]
87+
usage: Optional[Usage]
88+
89+
def __init__(
90+
self,
91+
model: Optional[str],
92+
role: Optional[str],
93+
content: Optional[List[ContentItem]],
94+
system_fingerprint: Optional[str],
95+
usage: Optional[Usage],
96+
):
97+
self.model = model
98+
self.role = role
99+
self.content = content if content is not None else []
100+
self.system_fingerprint = system_fingerprint
101+
self.usage = usage
102+
103+
104+
# The result would be an iterator that yields these Chunk objects
105+
StreamingResult = Iterator[Chunk]

src/langtrace_python_sdk/instrumentation/cohere/patch.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -367,10 +367,7 @@ def traced_method(wrapped, instance, args, kwargs):
367367
}
368368
for item in chat_history
369369
]
370-
if len(history) > 0:
371-
prompts = history + prompts
372-
if len(system_prompts) > 0:
373-
prompts = system_prompts + prompts
370+
prompts = system_prompts + history + prompts
374371

375372
span_attributes = {
376373
**get_langtrace_attributes(version, service_provider),

src/langtrace_python_sdk/instrumentation/crewai/patch.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88

99
from langtrace_python_sdk.constants import LANGTRACE_SDK_NAME
1010
from langtrace_python_sdk.constants.instrumentation.common import (
11-
LANGTRACE_ADDITIONAL_SPAN_ATTRIBUTES_KEY, SERVICE_PROVIDERS)
11+
LANGTRACE_ADDITIONAL_SPAN_ATTRIBUTES_KEY,
12+
SERVICE_PROVIDERS,
13+
)
1214
from langtrace_python_sdk.utils import set_span_attribute
1315
from langtrace_python_sdk.utils.llm import get_span_name, set_span_attributes
1416
from langtrace_python_sdk.utils.misc import serialize_args, serialize_kwargs
@@ -44,7 +46,9 @@ def traced_method(wrapped, instance, args, kwargs):
4446
set_span_attributes(span, attributes)
4547
result = wrapped(*args, **kwargs)
4648
if result is not None and len(result) > 0:
47-
set_span_attribute(span, "crewai.memory.storage.rag_storage.outputs", str(result))
49+
set_span_attribute(
50+
span, "crewai.memory.storage.rag_storage.outputs", str(result)
51+
)
4852
if result:
4953
span.set_status(Status(StatusCode.OK))
5054
span.end()
@@ -87,20 +91,17 @@ def traced_method(wrapped, instance, args, kwargs):
8791
CrewAISpanAttributes(span=span, instance=instance)
8892
result = wrapped(*args, **kwargs)
8993
if result:
94+
class_name = instance.__class__.__name__
95+
span.set_attribute(
96+
f"crewai.{class_name.lower()}.result", str(result)
97+
)
9098
span.set_status(Status(StatusCode.OK))
91-
if instance.__class__.__name__ == "Crew":
92-
span.set_attribute("crewai.crew.result", str(result))
93-
if hasattr(result, "tasks_output") and result.tasks_output is not None:
94-
span.set_attribute("crewai.crew.tasks_output", str((result.tasks_output)))
95-
if hasattr(result, "token_usage") and result.token_usage is not None:
96-
span.set_attribute("crewai.crew.token_usage", str((result.token_usage)))
97-
if hasattr(result, "usage_metrics") and result.usage_metrics is not None:
98-
span.set_attribute("crewai.crew.usage_metrics", str((result.usage_metrics)))
99-
elif instance.__class__.__name__ == "Agent":
100-
span.set_attribute("crewai.agent.result", str(result))
101-
elif instance.__class__.__name__ == "Task":
102-
span.set_attribute("crewai.task.result", str(result))
103-
99+
if class_name == "Crew":
100+
for attr in ["tasks_output", "token_usage", "usage_metrics"]:
101+
if hasattr(result, attr):
102+
span.set_attribute(
103+
f"crewai.crew.{attr}", str(getattr(result, attr))
104+
)
104105
span.end()
105106
return result
106107

0 commit comments

Comments
 (0)