Commit 736fb60

Merge branch 'main' of github.com:Scale3-Labs/langtrace-python-sdk into release
2 parents: a7b34ab + c792eb3

File tree: 11 files changed (+335, -12 lines)
Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
from pymilvus import MilvusClient, model
from typing import List
from langtrace_python_sdk import langtrace, with_langtrace_root_span
from dotenv import load_dotenv

load_dotenv()
langtrace.init()

client = MilvusClient("milvus_demo.db")

COLLECTION_NAME = "demo_collection"
embedding_fn = model.DefaultEmbeddingFunction()


def create_collection(collection_name: str = COLLECTION_NAME):
    if client.has_collection(collection_name=collection_name):
        client.drop_collection(collection_name=collection_name)

    client.create_collection(
        collection_name=collection_name,
        dimension=768,  # The vectors we will use in this demo have 768 dimensions
    )


def create_embedding(docs: List[str] = [], subject: str = "history"):
    """
    Create embeddings for the given documents.
    """

    vectors = embedding_fn.encode_documents(docs)
    # Each entity has an id, a vector representation, the raw text, and a subject
    # label that we use to demo metadata filtering later.
    data = [
        {"id": i, "vector": vectors[i], "text": docs[i], "subject": subject}
        for i in range(len(vectors))
    ]
    # print("Data has", len(data), "entities, each with fields: ", data[0].keys())
    # print("Vector dim:", len(data[0]["vector"]))
    return data


def insert_data(collection_name: str = COLLECTION_NAME, data: List[dict] = []):
    client.insert(
        collection_name=collection_name,
        data=data,
    )


def vector_search(collection_name: str = COLLECTION_NAME, queries: List[str] = []):
    query_vectors = embedding_fn.encode_queries(queries)
    # If you don't have the embedding function you can use a fake vector to finish the demo:
    # query_vectors = [ [ random.uniform(-1, 1) for _ in range(768) ] ]

    res = client.search(
        collection_name=collection_name,  # target collection
        data=query_vectors,  # query vectors
        limit=2,  # number of returned entities
        output_fields=["text", "subject"],  # specifies fields to be returned
        timeout=10,
        partition_names=["history"],
        anns_field="vector",
        search_params={"nprobe": 10},
    )
    return res


def query(collection_name: str = COLLECTION_NAME, query: str = ""):
    res = client.query(
        collection_name=collection_name,
        filter=query,
        # output_fields=["text", "subject"],
    )
    # print(res)
    return res


@with_langtrace_root_span("milvus_example")
def main():
    create_collection()
    # insert Alan Turing's history
    turing_data = create_embedding(
        docs=[
            "Artificial intelligence was founded as an academic discipline in 1956.",
            "Alan Turing was the first person to conduct substantial research in AI.",
            "Born in Maida Vale, London, Turing was raised in southern England.",
        ]
    )
    insert_data(data=turing_data)

    # insert AI Drug Discovery
    drug_data = create_embedding(
        docs=[
            "Machine learning has been used for drug design.",
            "Computational synthesis with AI algorithms predicts molecular properties.",
            "DDR1 is involved in cancers and fibrosis.",
        ],
        subject="biology",
    )
    insert_data(data=drug_data)

    vector_search(queries=["Who is Alan Turing?"])
    query(query="subject == 'history'")
    query(query="subject == 'biology'")


if __name__ == "__main__":
    main()
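
For reference, a short hedged sketch (not part of the commit) of how the demo's search results could be inspected. The list-of-lists hit shape with id, distance, and entity keys is the same one the new Milvus patch below reads; the variable names here are illustrative only.

# Hedged sketch: consume the hits returned by vector_search() above.
# MilvusClient.search returns one list of hits per query vector; each hit carries
# an id, a distance, and the requested output fields under "entity".
hits_per_query = vector_search(queries=["Who is Alan Turing?"])
for hits in hits_per_query or []:
    for hit in hits:
        print(hit["id"], hit["distance"], hit["entity"].get("text"))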

src/langtrace_python_sdk/constants/instrumentation/common.py

Lines changed: 1 addition & 0 deletions
@@ -37,6 +37,7 @@
    "MONGODB": "MongoDB",
    "AWS_BEDROCK": "AWS Bedrock",
    "CEREBRAS": "Cerebras",
+    "MILVUS": "Milvus",
}

LANGTRACE_ADDITIONAL_SPAN_ATTRIBUTES_KEY = "langtrace_additional_attributes"
Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
APIS = {
    "INSERT": {
        "MODULE": "pymilvus",
        "METHOD": "MilvusClient.insert",
        "OPERATION": "insert",
        "SPAN_NAME": "Milvus Insert",
    },
    "QUERY": {
        "MODULE": "pymilvus",
        "METHOD": "MilvusClient.query",
        "OPERATION": "query",
        "SPAN_NAME": "Milvus Query",
    },
    "SEARCH": {
        "MODULE": "pymilvus",
        "METHOD": "MilvusClient.search",
        "OPERATION": "search",
        "SPAN_NAME": "Milvus Search",
    },
    "DELETE": {
        "MODULE": "pymilvus",
        "METHOD": "MilvusClient.delete",
        "OPERATION": "delete",
        "SPAN_NAME": "Milvus Delete",
    },
    "CREATE_COLLECTION": {
        "MODULE": "pymilvus",
        "METHOD": "MilvusClient.create_collection",
        "OPERATION": "create_collection",
        "SPAN_NAME": "Milvus Create Collection",
    },
    "UPSERT": {
        "MODULE": "pymilvus",
        "METHOD": "MilvusClient.upsert",
        "OPERATION": "upsert",
        "SPAN_NAME": "Milvus Upsert",
    },
}
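
Each entry follows wrapt's module/dotted-name convention: MODULE is the importable package and METHOD is the attribute path inside it. A hedged sketch (not part of the SDK) of how an entry resolves to the callable that ends up being wrapped:

import importlib

def resolve(api: dict):
    # e.g. MODULE="pymilvus", METHOD="MilvusClient.search"
    target = importlib.import_module(api["MODULE"])
    for part in api["METHOD"].split("."):
        target = getattr(target, part)
    return target  # the function wrap_function_wrapper will patch

# resolve(APIS["SEARCH"]) -> <function MilvusClient.search>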

src/langtrace_python_sdk/instrumentation/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,7 @@
from .litellm import LiteLLMInstrumentation
from .pymongo import PyMongoInstrumentation
from .cerebras import CerebrasInstrumentation
+from .milvus import MilvusInstrumentation

__all__ = [
    "AnthropicInstrumentation",
@@ -50,4 +51,5 @@
    "PyMongoInstrumentation",
    "AWSBedrockInstrumentation",
    "CerebrasInstrumentation",
+    "MilvusInstrumentation",
]

src/langtrace_python_sdk/instrumentation/langchain_community/patch.py

Lines changed: 28 additions & 10 deletions
@@ -71,16 +71,26 @@ def traced_method(wrapped, instance, args, kwargs):
        result = wrapped(*args, **kwargs)
        if trace_output:
            span.set_attribute("langchain.outputs", to_json_string(result))
-
-        prompt_tokens = instance.get_num_tokens(args[0])
-        completion_tokens = instance.get_num_tokens(result)
-        if hasattr(result, 'usage'):
+        prompt_tokens = (
+            instance.get_num_tokens(args[0])
+            if hasattr(instance, "get_num_tokens")
+            else None
+        )
+        completion_tokens = (
+            instance.get_num_tokens(result)
+            if hasattr(instance, "get_num_tokens")
+            else None
+        )
+        if hasattr(result, "usage"):
            prompt_tokens = result.usage.prompt_tokens
            completion_tokens = result.usage.completion_tokens

-        span.set_attribute(SpanAttributes.LLM_USAGE_COMPLETION_TOKENS, prompt_tokens)
-        span.set_attribute(SpanAttributes.LLM_USAGE_PROMPT_TOKENS, completion_tokens)
-
+        span.set_attribute(
+            SpanAttributes.LLM_USAGE_COMPLETION_TOKENS, prompt_tokens
+        )
+        span.set_attribute(
+            SpanAttributes.LLM_USAGE_PROMPT_TOKENS, completion_tokens
+        )

        span.set_status(StatusCode.OK)
        return result
@@ -102,9 +112,17 @@ def clean_empty(d):
    if not isinstance(d, (dict, list, tuple)):
        return d
    if isinstance(d, tuple):
-        return tuple(val for val in (clean_empty(val) for val in d) if val != () and val is not None)
+        return tuple(
+            val
+            for val in (clean_empty(val) for val in d)
+            if val != () and val is not None
+        )
    if isinstance(d, list):
-        return [val for val in (clean_empty(val) for val in d) if val != [] and val is not None]
+        return [
+            val
+            for val in (clean_empty(val) for val in d)
+            if val != [] and val is not None
+        ]
    result = {}
    for k, val in d.items():
        if isinstance(val, dict):
@@ -120,7 +138,7 @@ def clean_empty(d):
            result[k] = val.strip()
        elif isinstance(val, object):
            # some langchain objects have a text attribute
-            val = getattr(val, 'text', None)
+            val = getattr(val, "text", None)
            if val is not None and val.strip() != "":
                result[k] = val.strip()
    return result
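
The new guard generalizes to any wrapper whose backing class may or may not expose a token counter; a standalone hedged sketch of the pattern (illustrative names, not SDK code):

# Hedged sketch: only ask the wrapped instance for token counts when it actually
# exposes get_num_tokens, and prefer provider-reported usage when the result carries it.
def extract_token_counts(instance, prompt, result):
    prompt_tokens = (
        instance.get_num_tokens(prompt) if hasattr(instance, "get_num_tokens") else None
    )
    completion_tokens = (
        instance.get_num_tokens(result) if hasattr(instance, "get_num_tokens") else None
    )
    if hasattr(result, "usage"):  # provider-reported counts win when available
        prompt_tokens = result.usage.prompt_tokens
        completion_tokens = result.usage.completion_tokens
    return prompt_tokens, completion_tokens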
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
from .instrumentation import MilvusInstrumentation

__all__ = ["MilvusInstrumentation"]
Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
from opentelemetry.trace import get_tracer

from typing import Collection
from importlib_metadata import version as v
from wrapt import wrap_function_wrapper as _W

from langtrace_python_sdk.constants.instrumentation.milvus import APIS
from .patch import generic_patch


class MilvusInstrumentation(BaseInstrumentor):

    def instrumentation_dependencies(self) -> Collection[str]:
        return ["pymilvus >= 2.4.1"]

    def _instrument(self, **kwargs):
        tracer_provider = kwargs.get("tracer_provider")
        tracer = get_tracer(__name__, "", tracer_provider)
        version = v("pymilvus")
        for api in APIS.values():
            _W(
                module=api["MODULE"],
                name=api["METHOD"],
                wrapper=generic_patch(api, version, tracer),
            )

    def _uninstrument(self, **kwargs):
        pass
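
For local verification, a minimal hedged usage sketch (not part of the commit): it relies only on BaseInstrumentor.instrument() passing a tracer_provider keyword through to _instrument above, and on the standard OpenTelemetry SDK console exporter. Normally langtrace.init() wires this up for you.

# Hedged sketch: attach MilvusInstrumentation to a local tracer provider so that
# pymilvus calls emit spans to the console.
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor
from langtrace_python_sdk.instrumentation import MilvusInstrumentation
from pymilvus import MilvusClient

provider = TracerProvider()
provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter()))

MilvusInstrumentation().instrument(tracer_provider=provider)

client = MilvusClient("milvus_demo.db")
client.create_collection(collection_name="spans_demo", dimension=8)  # emits "Milvus Create Collection"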
Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
from langtrace_python_sdk.utils.silently_fail import silently_fail
from opentelemetry.trace import Tracer
from opentelemetry.trace import SpanKind
from langtrace_python_sdk.utils import handle_span_error, set_span_attribute
from langtrace_python_sdk.utils.llm import (
    get_extra_attributes,
    set_span_attributes,
)
import json


def generic_patch(api, version: str, tracer: Tracer):
    def traced_method(wrapped, instance, args, kwargs):
        span_name = api["SPAN_NAME"]
        operation = api["OPERATION"]
        with tracer.start_as_current_span(span_name, kind=SpanKind.CLIENT) as span:
            try:
                span_attributes = {
                    "db.system": "milvus",
                    "db.operation": operation,
                    "db.name": kwargs.get("collection_name", None),
                    **get_extra_attributes(),
                }

                if operation == "create_collection":
                    set_create_collection_attributes(span_attributes, kwargs)

                elif operation == "insert" or operation == "upsert":
                    set_insert_or_upsert_attributes(span_attributes, kwargs)

                elif operation == "search":
                    set_search_attributes(span_attributes, kwargs)

                elif operation == "query":
                    set_query_attributes(span_attributes, kwargs)

                set_span_attributes(span, span_attributes)
                result = wrapped(*args, **kwargs)

                if operation == "query":
                    set_query_response_attributes(span, result)

                if operation == "search":
                    set_search_response_attributes(span, result)
                return result
            except Exception as err:
                handle_span_error(span, err)
                raise

    return traced_method


@silently_fail
def set_create_collection_attributes(span_attributes, kwargs):
    span_attributes["db.dimension"] = kwargs.get("dimension", None)


@silently_fail
def set_insert_or_upsert_attributes(span_attributes, kwargs):
    data = kwargs.get("data")
    timeout = kwargs.get("timeout")
    partition_name = kwargs.get("partition_name")

    span_attributes["db.num_entities"] = len(data) if data else None
    span_attributes["db.timeout"] = timeout
    span_attributes["db.partition_name"] = partition_name


@silently_fail
def set_search_attributes(span_attributes, kwargs):
    data = kwargs.get("data")
    filter = kwargs.get("filter")
    limit = kwargs.get("limit")
    output_fields = kwargs.get("output_fields")
    search_params = kwargs.get("search_params")
    timeout = kwargs.get("timeout")
    partition_names = kwargs.get("partition_names")
    anns_field = kwargs.get("anns_field")
    span_attributes["db.num_queries"] = len(data) if data else None
    span_attributes["db.filter"] = filter
    span_attributes["db.limit"] = limit
    span_attributes["db.output_fields"] = json.dumps(output_fields)
    span_attributes["db.search_params"] = json.dumps(search_params)
    span_attributes["db.partition_names"] = json.dumps(partition_names)
    span_attributes["db.anns_field"] = anns_field
    span_attributes["db.timeout"] = timeout


@silently_fail
def set_query_attributes(span_attributes, kwargs):
    filter = kwargs.get("filter")
    output_fields = kwargs.get("output_fields")
    timeout = kwargs.get("timeout")
    partition_names = kwargs.get("partition_names")
    ids = kwargs.get("ids")

    span_attributes["db.filter"] = filter
    span_attributes["db.output_fields"] = output_fields
    span_attributes["db.timeout"] = timeout
    span_attributes["db.partition_names"] = partition_names
    span_attributes["db.ids"] = ids


@silently_fail
def set_query_response_attributes(span, result):
    set_span_attribute(span, name="db.num_matches", value=len(result))
    for match in result:
        span.add_event(
            "db.query.match",
            attributes=match,
        )


@silently_fail
def set_search_response_attributes(span, result):
    for res in result:
        for match in res:
            span.add_event(
                "db.search.match",
                attributes={
                    "id": match["id"],
                    "distance": str(match["distance"]),
                    "entity": json.dumps(match["entity"]),
                },
            )
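
As a sanity check, a hedged test sketch (illustrative only, not part of the SDK test suite) that drives set_search_response_attributes with the list-of-lists hit shape MilvusClient.search returns and asserts the span event it should record:

# Hedged sketch: record one "db.search.match" event on a real SDK span.
# Assumes set_search_response_attributes is imported from the milvus patch module.
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter

exporter = InMemorySpanExporter()
provider = TracerProvider()
provider.add_span_processor(SimpleSpanProcessor(exporter))
tracer = provider.get_tracer(__name__)

fake_result = [[{"id": 0, "distance": 0.12, "entity": {"text": "stub", "subject": "history"}}]]

with tracer.start_as_current_span("Milvus Search") as span:
    set_search_response_attributes(span, fake_result)

events = exporter.get_finished_spans()[0].events
assert events[0].name == "db.search.match"
assert events[0].attributes["id"] == 0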

src/langtrace_python_sdk/instrumentation/pymongo/patch.py

Lines changed: 0 additions & 1 deletion
@@ -38,7 +38,6 @@ def traced_method(wrapped, instance, args, kwargs):

        try:
            result = wrapped(*args, **kwargs)
-            print(result)
            for doc in result:
                if span.is_recording():
                    span.add_event(
