Skip to content

Commit b25d659

Browse files
committed
Merge branch 'main' into development
2 parents 977fc78 + c024295 commit b25d659

File tree

7 files changed

+164
-46
lines changed

7 files changed

+164
-46
lines changed
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import json
2+
from typing import Dict
3+
4+
import boto3
5+
from dotenv import load_dotenv
6+
from langchain.chains.question_answering import load_qa_chain
7+
from langchain_community.llms.sagemaker_endpoint import (LLMContentHandler,
8+
SagemakerEndpoint)
9+
from langchain_core.documents import Document
10+
from langchain_core.prompts import PromptTemplate
11+
12+
from langtrace_python_sdk import langtrace, with_langtrace_root_span
13+
14+
# Add the path to the root of the project to the sys.path
15+
16+
load_dotenv()
17+
18+
langtrace.init()
19+
example_doc_1 = """
20+
Peter and Elizabeth took a taxi to attend the night party in the city. While in the party, Elizabeth collapsed and was rushed to the hospital.
21+
Since she was diagnosed with a brain injury, the doctor told Peter to stay besides her until she gets well.
22+
Therefore, Peter stayed with her at the hospital for 3 days without leaving.
23+
"""
24+
25+
docs = [
26+
Document(
27+
page_content=example_doc_1,
28+
)
29+
]
30+
31+
32+
query = """How long was Elizabeth hospitalized?"""
33+
prompt_template = """Use the following pieces of context to answer the question at the end.
34+
35+
{context}
36+
37+
Question: {question}
38+
Answer:"""
39+
PROMPT = PromptTemplate(
40+
template=prompt_template, input_variables=["context", "question"]
41+
)
42+
43+
44+
client = boto3.client(
45+
"sagemaker-runtime",
46+
region_name="us-east-1",
47+
)
48+
49+
50+
class ContentHandler(LLMContentHandler):
51+
content_type = "application/json"
52+
accepts = "application/json"
53+
54+
def transform_input(self, prompt: str, model_kwargs: Dict) -> bytes:
55+
input_str = json.dumps({"inputs": prompt, "parameters": model_kwargs})
56+
return input_str.encode("utf-8")
57+
58+
def transform_output(self, output: bytes) -> str:
59+
response_json = json.loads(output.read().decode("utf-8"))
60+
return response_json["generated_text"]
61+
62+
63+
@with_langtrace_root_span("SagemakerEndpoint")
64+
def main():
65+
content_handler = ContentHandler()
66+
67+
chain = load_qa_chain(
68+
llm=SagemakerEndpoint(
69+
endpoint_name="jumpstart-dft-meta-textgeneration-l-20240809-083223",
70+
client=client,
71+
model_kwargs={"temperature": 1e-10},
72+
content_handler=content_handler,
73+
),
74+
prompt=PROMPT,
75+
)
76+
77+
res = chain({"input_documents": docs, "question": query}, return_only_outputs=True)
78+
print(res)
79+
80+
81+
main()

src/langtrace_python_sdk/instrumentation/langchain_community/instrumentation.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,8 @@ def patch_module_classes(
4949
lambda member: inspect.isclass(member) and member.__module__ == module.__name__,
5050
):
5151
# loop through all public methods of the class
52-
for method_name, _ in inspect.getmembers(obj, predicate=inspect.isfunction):
53-
# Skip private methods
54-
if method_name.startswith("_"):
52+
for method_name, method in inspect.getmembers(obj, predicate=inspect.isfunction):
53+
if method.__qualname__.split('.')[0] != name:
5554
continue
5655
try:
5756
method_path = f"{name}.{method_name}"
@@ -82,6 +81,12 @@ def _instrument(self, **kwargs):
8281

8382
# List of modules to patch, with their corresponding patch names
8483
modules_to_patch = [
84+
(
85+
"langchain_community.llms.sagemaker_endpoint",
86+
"sagemaker_endpoint",
87+
True,
88+
True,
89+
),
8590
("langchain_community.document_loaders.pdf", "load_pdf", True, True),
8691
("langchain_community.vectorstores.faiss", "vector_store", False, False),
8792
("langchain_community.vectorstores.pgvector", "vector_store", False, False),

src/langtrace_python_sdk/instrumentation/langchain_community/patch.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ def traced_method(wrapped, instance, args, kwargs):
5050
**(extra_attributes if extra_attributes is not None else {}),
5151
}
5252

53+
span_attributes["langchain.metadata"] = to_json_string(kwargs)
54+
5355
if trace_input and len(args) > 0:
5456
span_attributes["langchain.inputs"] = to_json_string(args)
5557

@@ -86,15 +88,31 @@ def traced_method(wrapped, instance, args, kwargs):
8688

8789
def clean_empty(d):
8890
"""Recursively remove empty lists, empty dicts, or None elements from a dictionary."""
89-
if not isinstance(d, (dict, list)):
91+
if not isinstance(d, (dict, list, tuple)):
9092
return d
93+
if isinstance(d, tuple):
94+
return tuple(val for val in (clean_empty(val) for val in d) if val != () and val is not None)
9195
if isinstance(d, list):
92-
return [v for v in (clean_empty(v) for v in d) if v != [] and v is not None]
93-
return {
94-
k: v
95-
for k, v in ((k, clean_empty(v)) for k, v in d.items())
96-
if v is not None and v != {}
97-
}
96+
return [val for val in (clean_empty(val) for val in d) if val != [] and val is not None]
97+
result = {}
98+
for k, val in d.items():
99+
if isinstance(val, dict):
100+
val = clean_empty(val)
101+
if val != {} and val is not None:
102+
result[k] = val
103+
elif isinstance(val, list):
104+
val = [clean_empty(value) for value in val]
105+
if val != [] and val is not None:
106+
result[k] = val
107+
elif isinstance(val, str) and val is not None:
108+
if val.strip() != "":
109+
result[k] = val.strip()
110+
elif isinstance(val, object):
111+
# some langchain objects have a text attribute
112+
val = getattr(val, 'text', None)
113+
if val is not None and val.strip() != "":
114+
result[k] = val.strip()
115+
return result
98116

99117

100118
def custom_serializer(obj):
@@ -109,5 +127,12 @@ def custom_serializer(obj):
109127

110128
def to_json_string(any_object):
111129
"""Converts any object to a JSON-parseable string, omitting empty or None values."""
112-
cleaned_object = clean_empty(any_object)
113-
return json.dumps(cleaned_object, default=custom_serializer, indent=2)
130+
try:
131+
cleaned_object = clean_empty(any_object)
132+
return json.dumps(cleaned_object, default=custom_serializer, indent=2)
133+
except NotImplementedError:
134+
# Handle specific types that raise this error
135+
return str(any_object) # or another appropriate fallback
136+
except TypeError:
137+
# Handle cases where obj is not serializable
138+
return str(any_object)

src/langtrace_python_sdk/instrumentation/langchain_core/instrumentation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,8 @@ def patch_module_classes(
6666
if name.startswith("_") or name in exclude_classes:
6767
continue
6868
# loop through all public methods of the class
69-
for method_name, _ in inspect.getmembers(obj, predicate=inspect.isfunction):
70-
# Skip private methods
71-
if method_name.startswith("_") or method_name in exclude_methods:
69+
for method_name, method in inspect.getmembers(obj, predicate=inspect.isfunction):
70+
if method_name in exclude_methods or method.__qualname__.split('.')[0] != name:
7271
continue
7372
try:
7473
method_path = f"{name}.{method_name}"
@@ -126,6 +125,7 @@ def _instrument(self, **kwargs):
126125
modules_to_patch = [
127126
("langchain_core.retrievers", "retriever", generic_patch, True, True),
128127
("langchain_core.prompts.chat", "prompt", generic_patch, True, True),
128+
("langchain_core.language_models.llms", "generate", generic_patch, True, True),
129129
("langchain_core.runnables.base", "runnable", runnable_patch, True, True),
130130
(
131131
"langchain_core.runnables.passthrough",

src/langtrace_python_sdk/instrumentation/langchain_core/patch.py

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -59,19 +59,10 @@ def traced_method(wrapped, instance, args, kwargs):
5959
**(extra_attributes if extra_attributes is not None else {}),
6060
}
6161

62-
if len(args) > 0 and trace_input:
63-
inputs = {}
64-
for arg in args:
65-
if isinstance(arg, dict):
66-
for key, value in arg.items():
67-
if isinstance(value, list):
68-
for item in value:
69-
inputs[key] = item.__class__.__name__
70-
elif isinstance(value, str):
71-
inputs[key] = value
72-
elif isinstance(arg, str):
73-
inputs["input"] = arg
74-
span_attributes["langchain.inputs"] = to_json_string(inputs)
62+
if trace_input and len(args) > 0:
63+
span_attributes["langchain.inputs"] = to_json_string(args)
64+
65+
span_attributes["langchain.metadata"] = to_json_string(kwargs)
7566

7667
attributes = FrameworkSpanAttributes(**span_attributes)
7768

@@ -200,15 +191,31 @@ def traced_method(wrapped, instance, args, kwargs):
200191

201192
def clean_empty(d):
202193
"""Recursively remove empty lists, empty dicts, or None elements from a dictionary."""
203-
if not isinstance(d, (dict, list)):
194+
if not isinstance(d, (dict, list, tuple)):
204195
return d
196+
if isinstance(d, tuple):
197+
return tuple(val for val in (clean_empty(val) for val in d) if val != () and val is not None)
205198
if isinstance(d, list):
206-
return [v for v in (clean_empty(v) for v in d) if v != [] and v is not None]
207-
return {
208-
k: v
209-
for k, v in ((k, clean_empty(v)) for k, v in d.items())
210-
if v is not None and v != {}
211-
}
199+
return [val for val in (clean_empty(val) for val in d) if val != [] and val is not None]
200+
result = {}
201+
for k, val in d.items():
202+
if isinstance(val, dict):
203+
val = clean_empty(val)
204+
if val != {} and val is not None:
205+
result[k] = val
206+
elif isinstance(val, list):
207+
val = [clean_empty(value) for value in val]
208+
if val != [] and val is not None:
209+
result[k] = val
210+
elif isinstance(val, str) and val is not None:
211+
if val.strip() != "":
212+
result[k] = val.strip()
213+
elif isinstance(val, object):
214+
# some langchain objects have a text attribute
215+
val = getattr(val, 'text', None)
216+
if val is not None and val.strip() != "":
217+
result[k] = val.strip()
218+
return result
212219

213220

214221
def custom_serializer(obj):
@@ -223,5 +230,12 @@ def custom_serializer(obj):
223230

224231
def to_json_string(any_object):
225232
"""Converts any object to a JSON-parseable string, omitting empty or None values."""
226-
cleaned_object = clean_empty(any_object)
227-
return json.dumps(cleaned_object, default=custom_serializer, indent=2)
233+
try:
234+
cleaned_object = clean_empty(any_object)
235+
return json.dumps(cleaned_object, default=custom_serializer, indent=2)
236+
except NotImplementedError:
237+
# Handle specific types that raise this error
238+
return str(any_object) # or another appropriate fallback
239+
except TypeError:
240+
# Handle cases where obj is not serializable
241+
return str(any_object)
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "2.2.18"
1+
__version__ = "2.2.19"

src/tests/langchain/test_langchain.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,4 @@ def test_langchain(exporter):
1818
chain.invoke({"input": "how can langsmith help with testing?"})
1919
spans = exporter.get_finished_spans()
2020

21-
assert [
22-
"ChatPromptTemplate.invoke",
23-
"openai.chat.completions.create",
24-
"StrOutputParser.parse",
25-
"StrOutputParser.parse_result",
26-
"StrOutputParser.invoke",
27-
"RunnableSequence.invoke",
28-
] == [span.name for span in spans]
21+
assert len(spans) > 0

0 commit comments

Comments
 (0)