Skip to content

Commit 4573db0

Browse files
committed
milvus kickstart
1 parent 4988dee commit 4573db0

File tree

9 files changed

+279
-1
lines changed

9 files changed

+279
-1
lines changed
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
from pymilvus import MilvusClient, model
2+
from typing import List
3+
from langtrace_python_sdk import langtrace, with_langtrace_root_span
4+
from dotenv import load_dotenv
5+
6+
# Load variables from a local .env file into the environment before tracing
# is initialized (presumably where the Langtrace credentials live - confirm).
load_dotenv()
# Initialize Langtrace without echoing spans to the console.
langtrace.init(write_spans_to_console=False)

# Shared Milvus client for the whole demo, pointed at "milvus_demo.db".
client = MilvusClient("milvus_demo.db")

# Default collection name used by the helper functions in this script.
COLLECTION_NAME = "demo_collection"
# Embedding function shared by document and query encoding.
embedding_fn = model.DefaultEmbeddingFunction()
13+
14+
15+
def create_collection(collection_name: str = COLLECTION_NAME):
    """
    Recreate the named collection so the demo starts from an empty state.

    An existing collection with the same name is dropped first; a new one is
    then created for 768-dimensional vectors.
    """
    already_present = client.has_collection(collection_name=collection_name)
    if already_present:
        client.drop_collection(collection_name=collection_name)

    # The vectors used in this demo have 768 dimensions.
    client.create_collection(collection_name=collection_name, dimension=768)
23+
24+
25+
def create_embedding(
    docs: List[str] = [], subject: str = "history", start_id: int = 0
):
    """
    Create embeddings for the given documents.

    Args:
        docs: Raw texts to embed. The default list is only read, never
            mutated, so sharing it between calls is harmless.
        subject: Label stored on every entity; used to demo metadata
            filtering later.
        start_id: Id given to the first entity; later entities count up from
            it. Pass the number of entities already inserted so that ids stay
            unique across batches. Defaults to 0, the previous behavior.

    Returns:
        A list of dicts, one per document, with the keys ``id``, ``vector``,
        ``text`` and ``subject``.
    """
    vectors = embedding_fn.encode_documents(docs)
    # Each entity has id, vector representation, raw text, and a subject label.
    return [
        {
            "id": start_id + index,
            "vector": vector,
            "text": docs[index],
            "subject": subject,
        }
        for index, vector in enumerate(vectors)
    ]
40+
41+
42+
def insert_data(collection_name: str = COLLECTION_NAME, data: List[dict] = []):
    """Insert the prepared entities in ``data`` into ``collection_name``."""
    client.insert(collection_name=collection_name, data=data)
47+
48+
49+
def vector_search(collection_name: str = COLLECTION_NAME, queries: List[str] = []):
    """
    Embed the query strings and run a vector search against a collection.

    Args:
        collection_name: Collection to search. Previously ignored in favor of
            a hard-coded "demo_collection"; the default keeps that target.
        queries: Natural-language queries to embed and search for.

    Returns:
        Whatever ``client.search`` returns (previously discarded), limited to
        2 hits per query and carrying the ``text`` and ``subject`` fields.
    """
    query_vectors = embedding_fn.encode_queries(queries)
    # Without the embedding function, a fake vector is enough to finish the
    # demo, e.g. one list of 768 random floats in [-1, 1].

    return client.search(
        collection_name=collection_name,  # target collection
        data=query_vectors,  # query vectors
        limit=2,  # number of returned entities
        output_fields=["text", "subject"],  # specifies fields to be returned
    )
60+
61+
62+
def query(collection_name: str = COLLECTION_NAME, query: str = ""):
    """
    Run a scalar (metadata) query against a collection.

    Args:
        collection_name: Collection to query.
        query: Filter expression passed to Milvus, e.g. "subject == 'history'".

    Returns:
        Whatever ``client.query`` returns. The result used to be bound to an
        unused local and dropped; callers that ignore it are unaffected.
    """
    return client.query(
        collection_name=collection_name,
        filter=query,
    )
70+
71+
72+
@with_langtrace_root_span("milvus_example")
def main():
    """
    Run the Milvus demo end to end under one Langtrace root span.

    Recreates the demo collection, inserts a "history" batch and a "biology"
    batch, then performs one vector search and two metadata queries.
    """
    create_collection()

    # insert Alan Turing's history
    turing_data = create_embedding(
        docs=[
            "Artificial intelligence was founded as an academic discipline in 1956.",
            "Alan Turing was the first person to conduct substantial research in AI.",
            "Born in Maida Vale, London, Turing was raised in southern England.",
        ]
    )
    insert_data(data=turing_data)

    # insert AI Drug Discovery
    drug_data = create_embedding(
        docs=[
            "Machine learning has been used for drug design.",
            "Computational synthesis with AI algorithms predicts molecular properties.",
            "DDR1 is involved in cancers and fibrosis.",
        ],
        subject="biology",
    )
    # create_embedding numbers every batch from 0, so shift this batch past
    # the first one; otherwise both batches would be inserted with ids 0..2.
    id_offset = len(turing_data)
    for entity in drug_data:
        entity["id"] += id_offset
    insert_data(data=drug_data)

    vector_search(queries=["Who is Alan Turing?"])
    query(query="subject == 'history'")
    query(query="subject == 'biology'")
99+
100+
101+
# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

src/langtrace_python_sdk/constants/instrumentation/common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
"MONGODB": "MongoDB",
3838
"AWS_BEDROCK": "AWS Bedrock",
3939
"CEREBRAS": "Cerebras",
40+
"MILVUS": "Milvus",
4041
}
4142

4243
LANGTRACE_ADDITIONAL_SPAN_ATTRIBUTES_KEY = "langtrace_additional_attributes"
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Span name for every instrumented MilvusClient method, keyed by method name.
_SPAN_NAMES = {
    "insert": "Milvus Insert",
    "query": "Milvus Query",
    "search": "Milvus Search",
    "delete": "Milvus Delete",
    "create_collection": "Milvus Create Collection",
    "upsert": "Milvus Upsert",
}

# One entry per instrumented API, keyed by the upper-cased operation name.
# Each entry gives the module and dotted method path to wrap, the operation
# name, and the span name to use for it.
APIS = {
    operation.upper(): {
        "MODULE": "pymilvus",
        "METHOD": f"MilvusClient.{operation}",
        "OPERATION": operation,
        "SPAN_NAME": span_name,
    }
    for operation, span_name in _SPAN_NAMES.items()
}

src/langtrace_python_sdk/instrumentation/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from .litellm import LiteLLMInstrumentation
2424
from .pymongo import PyMongoInstrumentation
2525
from .cerebras import CerebrasInstrumentation
26+
from .milvus import MilvusInstrumentation
2627

2728
__all__ = [
2829
"AnthropicInstrumentation",
@@ -50,4 +51,5 @@
5051
"PyMongoInstrumentation",
5152
"AWSBedrockInstrumentation",
5253
"CerebrasInstrumentation",
54+
"MilvusInstrumentation",
5355
]
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .instrumentation import MilvusInstrumentation
2+
3+
# Names exported by `from <this package> import *`.
__all__ = ["MilvusInstrumentation"]
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
2+
from opentelemetry.trace import get_tracer
3+
4+
from typing import Collection
5+
from importlib_metadata import version as v
6+
from wrapt import wrap_function_wrapper as _W
7+
8+
from langtrace_python_sdk.constants.instrumentation.milvus import APIS
9+
from .patch import generic_patch
10+
11+
12+
class MilvusInstrumentation(BaseInstrumentor):
    """Instrumentor that wraps the pymilvus methods listed in ``APIS``."""

    def instrumentation_dependencies(self) -> Collection[str]:
        """Return the package requirement this instrumentation applies to."""
        return ["pymilvus >= 2.4.1"]

    def _instrument(self, **kwargs):
        """
        Wrap every method described in ``APIS`` with a tracing wrapper.

        Reads an optional ``tracer_provider`` from ``kwargs`` and hands the
        installed pymilvus version to each wrapper.
        """
        tracer = get_tracer(__name__, "", kwargs.get("tracer_provider"))
        pymilvus_version = v("pymilvus")

        for api_spec in APIS.values():
            patched = generic_patch(api_spec, pymilvus_version, tracer)
            _W(
                module=api_spec["MODULE"],
                name=api_spec["METHOD"],
                wrapper=patched,
            )

    def _uninstrument(self, **kwargs):
        """Do nothing; the wrappers installed by ``_instrument`` stay in place."""
        pass
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
from opentelemetry.trace import Tracer
2+
from opentelemetry.trace import SpanKind
3+
from langtrace_python_sdk.utils import handle_span_error
4+
from langtrace_python_sdk.utils.llm import (
5+
get_extra_attributes,
6+
set_span_attributes,
7+
)
8+
9+
10+
def generic_patch(api, version: str, tracer: Tracer):
    """
    Build a wrapper that traces one Milvus client method.

    Args:
        api: Entry of the ``APIS`` table; ``SPAN_NAME`` and ``OPERATION`` are
            read from it.
        version: Installed pymilvus version.
            NOTE(review): currently not recorded on the span - confirm whether
            it should be exported as an attribute like other integrations do.
        tracer: Tracer used to start the span.

    Returns:
        A ``(wrapped, instance, args, kwargs)`` callable that runs the wrapped
        method inside a CLIENT span and returns its result unchanged.
        Exceptions are reported through ``handle_span_error`` and re-raised.
    """

    def traced_method(wrapped, instance, args, kwargs):
        span_name = api["SPAN_NAME"]
        operation = api["OPERATION"]

        # Fall back to the first positional argument so db.name is not lost
        # when the caller writes client.insert("my_collection", data).
        # Assumes collection_name is the first parameter of every wrapped
        # MilvusClient method - see the pymilvus API reference.
        collection_name = kwargs.get("collection_name")
        if collection_name is None and args:
            collection_name = args[0]

        with tracer.start_as_current_span(span_name, kind=SpanKind.CLIENT) as span:
            try:
                span_attributes = {
                    "db.system": "milvus",
                    "db.operation": operation,
                    "db.name": collection_name,
                    **get_extra_attributes(),
                }

                # Operation-specific request attributes. The helpers only
                # inspect keyword arguments; "delete" has none of its own.
                if operation == "create_collection":
                    set_create_collection_attributes(span_attributes, kwargs)
                elif operation in ("insert", "upsert"):
                    set_insert_or_upsert_attributes(span_attributes, kwargs)
                elif operation == "search":
                    set_search_attributes(span_attributes, kwargs)
                elif operation == "query":
                    set_query_attributes(span_attributes, kwargs)

                set_span_attributes(span, span_attributes)
                result = wrapped(*args, **kwargs)

                # Only query results are attached to the span, one event per row.
                if operation == "query":
                    set_query_response_attributes(span, result)

                return result
            except Exception as err:
                handle_span_error(span, err)
                raise

    return traced_method
47+
48+
49+
def set_create_collection_attributes(span_attributes, kwargs):
    """Record the requested vector ``dimension`` (None if absent) in place."""
    dimension = kwargs.get("dimension")
    span_attributes["db.dimension"] = dimension
51+
52+
53+
def set_insert_or_upsert_attributes(span_attributes, kwargs):
    """
    Record insert/upsert request details on ``span_attributes`` in place.

    Sets ``db.num_entities``, ``db.timeout`` and ``db.partition_name`` from
    the keyword arguments of the wrapped call; missing values become None.
    """
    data = kwargs.get("data")

    if not data:
        # Missing or empty payload: keep reporting "unknown" as before.
        num_entities = None
    elif isinstance(data, dict):
        # A single entity may be passed as one dict; len() would count its
        # fields rather than entities.
        num_entities = 1
    else:
        num_entities = len(data)

    span_attributes["db.num_entities"] = num_entities
    span_attributes["db.timeout"] = kwargs.get("timeout")
    span_attributes["db.partition_name"] = kwargs.get("partition_name")
61+
62+
63+
def set_search_attributes(span_attributes, kwargs):
    """
    Record vector-search request details on ``span_attributes`` in place.

    Reads ``data``, ``filter``, ``limit``, ``output_fields``,
    ``search_params``, ``partition_names``, ``anns_field`` and ``timeout``
    from the keyword arguments of the wrapped call; missing values become
    None.
    """
    data = kwargs.get("data")

    # Test against None rather than truthiness: `if data` raises ValueError
    # for multi-element numpy arrays, which would abort the caller's search.
    # An empty payload still maps to None, as before.
    num_queries = (len(data) or None) if data is not None else None

    span_attributes.update(
        {
            "db.num_queries": num_queries,
            "db.filter": kwargs.get("filter"),
            "db.limit": kwargs.get("limit"),
            "db.output_fields": kwargs.get("output_fields"),
            "db.search_params": kwargs.get("search_params"),
            "db.partition_names": kwargs.get("partition_names"),
            "db.anns_field": kwargs.get("anns_field"),
            "db.timeout": kwargs.get("timeout"),
        }
    )
81+
82+
83+
def set_query_attributes(span_attributes, kwargs):
    """
    Record scalar-query request details on ``span_attributes`` in place.

    Copies ``filter``, ``output_fields``, ``timeout``, ``partition_names``
    and ``ids`` from the keyword arguments of the wrapped call; missing
    values are stored as None.
    """
    span_attributes.update(
        {
            "db.filter": kwargs.get("filter"),
            "db.output_fields": kwargs.get("output_fields"),
            "db.timeout": kwargs.get("timeout"),
            "db.partition_names": kwargs.get("partition_names"),
            "db.ids": kwargs.get("ids"),
        }
    )
95+
96+
97+
def set_query_response_attributes(span, result):
    """
    Attach every row of a query result to the span as a "db.query.match" event.

    Args:
        span: Span to annotate; must provide ``is_recording`` and ``add_event``.
        result: Iterable of result rows, each used as the event's attributes.
            A None or empty result adds no events.
    """
    # Skip the work for non-recording spans and avoid iterating a None result,
    # which would raise inside the traced call.
    if not result or not span.is_recording():
        return

    for match in result:
        span.add_event(
            "db.query.match",
            attributes=match,
        )

src/langtrace_python_sdk/instrumentation/pymongo/patch.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ def traced_method(wrapped, instance, args, kwargs):
3838

3939
try:
4040
result = wrapped(*args, **kwargs)
41-
print(result)
4241
for doc in result:
4342
if span.is_recording():
4443
span.add_event(

src/langtrace_python_sdk/langtrace.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
WeaviateInstrumentation,
6767
PyMongoInstrumentation,
6868
CerebrasInstrumentation,
69+
MilvusInstrumentation,
6970
)
7071
from opentelemetry.util.re import parse_env_headers
7172

@@ -284,6 +285,7 @@ def init(
284285
"autogen": AutogenInstrumentation(),
285286
"pymongo": PyMongoInstrumentation(),
286287
"cerebras-cloud-sdk": CerebrasInstrumentation(),
288+
"pymilvus": MilvusInstrumentation(),
287289
}
288290

289291
init_instrumentations(config.disable_instrumentations, all_instrumentations)

0 commit comments

Comments
 (0)