Skip to content

Commit 1911ac4

Browse files
committed
feat: adding support for basic pgvector tracing
1 parent 587135a commit 1911ac4

File tree

9 files changed

+202
-1
lines changed

9 files changed

+202
-1
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ dev = [
5757
"google-cloud-aiplatform",
5858
"mistralai",
5959
"embedchain",
60+
"psycopg",
61+
"pgvector"
6062
]
6163

6264
test = ["pytest", "pytest-vcr", "pytest-asyncio"]
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from openai import OpenAI
2+
from pgvector.psycopg import register_vector
3+
import psycopg
4+
from langtrace_python_sdk import langtrace
5+
6+
langtrace.init(write_spans_to_console=True)
7+
8+
conn = psycopg.connect(dbname='postgres', autocommit=True, password="mypw", user="postgres", host="localhost")
9+
conn.execute('CREATE EXTENSION IF NOT EXISTS vector')
10+
register_vector(conn)
11+
12+
client = OpenAI()
13+
14+
15+
def setup_db():
    """Recreate the ``documents`` table and fill it with sample embeddings.

    Drops any existing ``documents`` table, creates it again with a
    1536-dimension ``vector`` column, requests embeddings for three sample
    sentences from the module-level OpenAI ``client``, and inserts one row
    per sentence through the module-level ``conn``.
    """
    conn.execute('DROP TABLE IF EXISTS documents')
    conn.execute('CREATE TABLE documents (id bigserial PRIMARY KEY, content text, embedding vector(1536))')

    sentences = [
        'The dog is barking',
        'The cat is purring',
        'The bear is growling'
    ]

    embedding_response = client.embeddings.create(input=sentences, model='text-embedding-3-small')

    # Each response item is paired with the sentence at the same position.
    for sentence, item in zip(sentences, embedding_response.data):
        conn.execute(
            'INSERT INTO documents (content, embedding) VALUES (%s, %s)',
            (sentence, item.embedding),
        )
30+
31+
32+
def basic_pgvector():
    """Run a nearest-neighbour lookup against the sample ``documents`` table.

    Rebuilds the table via ``setup_db()``, then selects up to five other rows
    ordered by the ``<=>`` distance between their embedding and the embedding
    of the row with id 1. Each neighbour's content is printed, followed by
    the full result list.

    Returns:
        The rows returned by ``fetchall()`` for the neighbour query.
    """
    setup_db()

    neighbor_query = 'SELECT content FROM documents WHERE id != %(id)s ORDER BY embedding <=> (SELECT embedding FROM documents WHERE id = %(id)s) LIMIT 5'
    cursor = conn.execute(neighbor_query, {'id': 1})
    neighbors = cursor.fetchall()

    for row in neighbors:
        print(row[0])

    print("neighbors", neighbors)
    return neighbors
41+

src/langtrace_python_sdk/constants/instrumentation/common.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
"MISTRAL": "Mistral",
3333
"EMBEDCHAIN": "Embedchain",
3434
"AUTOGEN": "Autogen",
35+
"PGVECTOR": "PgVector",
3536
}
3637

3738
LANGTRACE_ADDITIONAL_SPAN_ATTRIBUTES_KEY = "langtrace_additional_attributes"

src/langtrace_python_sdk/instrumentation/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from .gemini import GeminiInstrumentation
2020
from .mistral import MistralInstrumentation
2121
from .embedchain import EmbedchainInstrumentation
22+
from .pgvector import PgVectorInstrumentation
2223

2324
__all__ = [
2425
"AnthropicInstrumentation",
@@ -42,4 +43,5 @@
4243
"VertexAIInstrumentation",
4344
"GeminiInstrumentation",
4445
"MistralInstrumentation",
46+
"PgVectorInstrumentation",
4547
]
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
from .instrumentation import PgVectorInstrumentation
2+
3+
__all__ = [
4+
"PgVectorInstrumentation",
5+
]
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
"""
2+
Copyright (c) 2024 Scale3 Labs
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import importlib.metadata
18+
import logging
19+
from typing import Collection
20+
21+
from opentelemetry.instrumentation.instrumentor import BaseInstrumentor
22+
from opentelemetry.trace import get_tracer
23+
from wrapt import wrap_function_wrapper
24+
25+
from langtrace_python_sdk.constants.instrumentation.pgvector import APIS
26+
from langtrace_python_sdk.instrumentation.pgvector.patch import (
27+
generic_patch,
28+
)
29+
30+
logging.basicConfig(level=logging.DEBUG) # Set to DEBUG for detailed logging
31+
32+
33+
class PgVectorInstrumentation(BaseInstrumentor):
    """
    The PgVectorInstrumentation class represents the instrumentation for the Postgres Vector.

    On instrumentation, every entry of ``APIS`` is wrapped with the tracing
    wrapper produced by ``generic_patch``.
    """

    def instrumentation_dependencies(self) -> Collection[str]:
        """Return the package requirements this instrumentation declares."""
        return ["pgvector >= 0.3.3", "trace-attributes >= 7.1.0"]

    def _instrument(self, **kwargs):
        """Wrap each configured API method with a tracing wrapper.

        Args:
            **kwargs: May contain ``tracer_provider``; it is passed to
                ``get_tracer`` (``None`` when absent).
        """
        tracer_provider = kwargs.get("tracer_provider")
        tracer = get_tracer(__name__, "", tracer_provider)
        version = importlib.metadata.version("pgvector")

        # Report through logging rather than print(): a library should not
        # write to the host application's stdout.
        logging.getLogger(__name__).debug("Instrumenting PgVector")
        for api_name, api_config in APIS.items():
            wrap_function_wrapper(
                api_config["MODULE"],
                api_config["METHOD"],
                generic_patch(api_name, version, tracer),
            )

    def _instrument_module(self, module_name):
        """No-op hook; nothing is instrumented per module."""
        pass

    def _uninstrument(self, **kwargs):
        """No-op: wrappers installed by ``_instrument`` are not removed here."""
        pass
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
"""
2+
Copyright (c) 2024 Scale3 Labs
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import re
18+
from importlib_metadata import version as v
19+
from langtrace.trace_attributes import DatabaseSpanAttributes
20+
from opentelemetry import baggage, trace
21+
from opentelemetry.trace import SpanKind,Tracer
22+
from opentelemetry.trace.propagation import set_span_in_context
23+
from opentelemetry.trace.status import Status, StatusCode
24+
25+
from langtrace_python_sdk.constants import LANGTRACE_SDK_NAME
26+
from langtrace_python_sdk.constants.instrumentation.common import (
27+
LANGTRACE_ADDITIONAL_SPAN_ATTRIBUTES_KEY,
28+
SERVICE_PROVIDERS,
29+
)
30+
from langtrace_python_sdk.constants.instrumentation.pgvector import APIS
31+
from langtrace_python_sdk.utils.llm import get_span_name
32+
from langtrace_python_sdk.utils.misc import extract_input_params, to_iso_format
33+
34+
35+
# Utility function to extract table name
36+
def extract_table_name(query):
    """Best-effort extraction of the first table name referenced by a query.

    Looks for the first identifier that follows a ``FROM``, ``INTO`` or
    ``UPDATE`` keyword (``DELETE FROM`` is covered by ``FROM``). This is a
    heuristic for basic statements, not a SQL parser.

    Args:
        query: The SQL statement. ``str`` is searched directly, ``bytes`` /
            ``bytearray`` are decoded as UTF-8 first. Any other object
            (``None``, composed SQL objects, ...) yields ``None`` instead of
            raising, so tracing can never break the caller's query.

    Returns:
        The table name, including a schema prefix when present
        (e.g. ``"public.documents"``), or ``None`` if nothing matched.
    """
    if isinstance(query, (bytes, bytearray)):
        query = query.decode("utf-8", errors="replace")
    if not isinstance(query, str):
        return None

    # \b keeps identifiers such as ``valid_from`` from being read as the
    # FROM keyword; the dotted group keeps schema-qualified names intact.
    match = re.search(
        r'\b(?:from|into|update)\s+((?:\w+\.)*\w+)', query, re.IGNORECASE
    )
    return match.group(1) if match else None
42+
43+
44+
def generic_patch(method_name, version, tracer: Tracer):
    """Build a tracing wrapper for one instrumented database method.

    Args:
        method_name: Key into ``APIS``; also used to derive the span name.
        version: Version string reported as ``langtrace.service.version``.
        tracer: Tracer used to open the span around the wrapped call.

    Returns:
        A ``(wrapped, instance, args, kwargs)`` callable that opens a CLIENT
        span, attaches database attributes, calls the original method and
        returns its result. Exceptions are recorded on the span and re-raised.
    """

    def traced_method(wrapped, instance, args, kwargs):
        # The SQL may be passed positionally or by keyword; indexing args[0]
        # unconditionally would raise IndexError inside the tracer and break
        # a call that is otherwise valid.
        # NOTE(review): assumes the wrapped method names this parameter
        # ``query`` (as psycopg's execute does) - confirm against APIS.
        query = args[0] if args else kwargs.get("query")
        api = APIS[method_name]
        service_provider = SERVICE_PROVIDERS["PGVECTOR"]
        extra_attributes = baggage.get_baggage(LANGTRACE_ADDITIONAL_SPAN_ATTRIBUTES_KEY)

        span_attributes = {
            "langtrace.sdk.name": "langtrace-python-sdk",
            "langtrace.service.name": service_provider,
            "langtrace.service.type": "vectordb",
            "langtrace.service.version": version,
            "langtrace.version": v(LANGTRACE_SDK_NAME),
            "db.system": "pgvector",
            "db.operation": api["OPERATION"],
            "db.collection.name": extract_table_name(query),
            "db.namespace": instance.connection.info.dbname,
            "db.query": str(query),
            # Caller-supplied attributes from baggage override the defaults.
            **(extra_attributes if extra_attributes is not None else {}),
        }

        attributes = DatabaseSpanAttributes(**span_attributes)

        with tracer.start_as_current_span(
            name=get_span_name(method_name),
            kind=SpanKind.CLIENT,
            context=set_span_in_context(trace.get_current_span()),
        ) as span:
            # Only attributes with a value are set on the span.
            for field, value in attributes.model_dump(by_alias=True).items():
                if value is not None:
                    span.set_attribute(field, value)
            try:
                # Attempt to call the original method
                result = wrapped(*args, **kwargs)
                span.set_status(StatusCode.OK)
                return result
            except Exception as err:
                # Record the exception in the span
                span.record_exception(err)
                # Set the span status to indicate an error
                span.set_status(Status(StatusCode.ERROR, str(err)))
                # Reraise the exception to ensure it's not swallowed
                raise

    return traced_method

src/langtrace_python_sdk/langtrace.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
MistralInstrumentation,
5151
OllamaInstrumentor,
5252
OpenAIInstrumentation,
53+
PgVectorInstrumentation,
5354
PineconeInstrumentation,
5455
QdrantInstrumentation,
5556
AutogenInstrumentation,
@@ -148,6 +149,7 @@ def init(
148149
"google-generativeai": GeminiInstrumentation(),
149150
"mistralai": MistralInstrumentation(),
150151
"autogen": AutogenInstrumentation(),
152+
"pgvector": PgVectorInstrumentation(),
151153
}
152154

153155
init_instrumentations(disable_instrumentations, all_instrumentations)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "2.3.20"
1+
__version__ = "2.4.0"

0 commit comments

Comments
 (0)