Skip to content

Commit a87a6da

Browse files
committed
Merge branch 'development' of github.com:Scale3-Labs/langtrace-python-sdk into development
2 parents 698d2ff + 318c215 commit a87a6da

File tree

7 files changed

+194
-27
lines changed

7 files changed

+194
-27
lines changed
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
import json
import operator
from typing import Annotated, TypedDict, Union

from dotenv import load_dotenv
from langchain import hub
from langchain.agents import create_openai_tools_agent
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.graph import END, StateGraph

from langtrace_python_sdk import langtrace, with_langtrace_root_span

# Load OPENAI_API_KEY / LANGTRACE_API_KEY etc. from a local .env file.
load_dotenv()

# Initialise Langtrace tracing; spans are exported, not printed to stdout.
langtrace.init(write_spans_to_console=False)
class AgentState(TypedDict):
    """Shared state threaded through every node of the LangGraph graph.

    ``intermediate_steps`` is annotated with ``operator.add`` so that each
    node's update is appended to the accumulated list rather than replacing it.
    """

    # The user's original question.
    input: str
    # Output of the most recent agent/LLM step (None before the first step).
    agent_out: Union[AgentAction, AgentFinish, None]
    # Accumulated (action, observation) pairs; merged via operator.add.
    intermediate_steps: Annotated[list[tuple[AgentAction, str]], operator.add]
23+
24+
25+
# Static document returned by the search tool — emulates a RAG retrieval hit
# (title/summary/authors/source of the EHI paper, arXiv:2310.08891).
ehi_information = """Title: EHI: End-to-end Learning of Hierarchical Index for
Efficient Dense Retrieval
Summary: Dense embedding-based retrieval is now the industry
standard for semantic search and ranking problems, like obtaining relevant web
documents for a given query. Such techniques use a two-stage process: (a)
contrastive learning to train a dual encoder to embed both the query and
documents and (b) approximate nearest neighbor search (ANNS) for finding similar
documents for a given query. These two stages are disjoint; the learned
embeddings might be ill-suited for the ANNS method and vice-versa, leading to
suboptimal performance. In this work, we propose End-to-end Hierarchical
Indexing -- EHI -- that jointly learns both the embeddings and the ANNS
structure to optimize retrieval performance. EHI uses a standard dual encoder
model for embedding queries and documents while learning an inverted file index
(IVF) style tree structure for efficient ANNS. To ensure stable and efficient
learning of discrete tree-based ANNS structure, EHI introduces the notion of
dense path embedding that captures the position of a query/document in the tree.
We demonstrate the effectiveness of EHI on several benchmarks, including
de-facto industry standard MS MARCO (Dev set and TREC DL19) datasets. For
example, with the same compute budget, EHI outperforms state-of-the-art (SOTA)
in by 0.6% (MRR@10) on MS MARCO dev set and by 4.2% (nDCG@10) on TREC DL19
benchmarks.
Author(s): Ramnath Kumar, Anshul Mittal, Nilesh Gupta, Aditya Kusupati,
Inderjit Dhillon, Prateek Jain
Source: https://arxiv.org/pdf/2310.08891.pdf"""
49+
50+
51+
@tool("search")
def search_tool(query: str):
    """Searches for information on the topic of artificial intelligence (AI).
    Cannot be used to research any other topics. Search query must be provided
    in natural language and be verbose."""
    # NOTE: the docstring above is the tool description shown to the LLM, so it
    # must stay as-is.  This is a "RAG" emulator: it ignores the query and
    # always returns the canned EHI paper text.
    return ehi_information
58+
59+
60+
@tool("final_answer")
def final_answer_tool(answer: str, source: str):
    """Returns a natural language response to the user in `answer`, and a
    `source` which provides citations for where this information came from.
    """
    # The body is intentionally a no-op: the tool exists only so the LLM emits
    # a structured {answer, source} tool call — presumably the caller reads the
    # call arguments, not this return value.  TODO(review): confirm.
    return ""
66+
67+
68+
# Default OpenAI chat model plus the standard OpenAI-functions agent prompt
# pulled from the LangChain hub.
llm = ChatOpenAI()
prompt = hub.pull("hwchase17/openai-functions-agent")

# Agent runnable that chooses between the two tools defined above.
query_agent_runnable = create_openai_tools_agent(
    llm=llm, tools=[final_answer_tool, search_tool], prompt=prompt
)

inputs = {"input": "what are EHI embeddings?", "intermediate_steps": []}

# NOTE(review): this module-level invoke fires one LLM call at import time and
# its result is never used (run_query_agent recomputes it inside the graph).
# Kept to preserve the example's behavior, but it is likely demonstrative only.
agent_out = query_agent_runnable.invoke(inputs)
80+
81+
82+
def run_query_agent(state: "AgentState"):
    """Graph node: run the tool-choosing agent over the current graph state.

    Returns a partial state update containing the agent's raw output.
    """
    print("> run_query_agent")
    agent_out = query_agent_runnable.invoke(state)
    return {"agent_out": agent_out}
86+
87+
88+
def execute_search(state: "AgentState"):
    """Graph node: execute the search tool call chosen by the query agent.

    Returns a partial state update appending the search observation to
    ``intermediate_steps``.
    """
    print("> execute_search")
    action = state["agent_out"]
    # agent_out is expected to be a list of agent actions; dig the raw OpenAI
    # tool_call out of the last message so the generated arguments JSON can be
    # recovered and passed to the tool.
    tool_call = action[-1].message_log[-1].additional_kwargs["tool_calls"][-1]
    out = search_tool.invoke(json.loads(tool_call["function"]["arguments"]))
    return {"intermediate_steps": [{"search": str(out)}]}
94+
95+
96+
def router(state: "AgentState"):
    """Conditional-edge router for the graph.

    Returns the name of the tool the agent chose (e.g. "search" or
    "final_answer"), or "error" when ``agent_out`` is not a tool-call list.
    """
    print("> router")
    if isinstance(state["agent_out"], list):
        return state["agent_out"][-1].tool
    return "error"
102+
103+
104+
# Finally, a single LLM call that MUST use the final_answer tool: binding with
# tool_choice="final_answer" forces every response into that structured schema.
final_answer_llm = llm.bind_tools([final_answer_tool], tool_choice="final_answer")
106+
107+
108+
# This forced final_answer LLM call structures the output of our RAG endpoint.
def rag_final_answer(state: "AgentState"):
    """Graph node: answer the user's question from the retrieved context.

    Uses ``final_answer_llm`` so the reply is forced into the
    final_answer {answer, source} tool schema; returns the raw JSON-encoded
    tool arguments as the new ``agent_out``.
    """
    print("> final_answer")
    query = state["input"]
    context = state["intermediate_steps"][-1]

    prompt = f"""You are a helpful assistant, answer the user's question using the
context provided.

CONTEXT: {context}

QUESTION: {query}
"""
    out = final_answer_llm.invoke(prompt)
    function_call = out.additional_kwargs["tool_calls"][-1]["function"]["arguments"]
    return {"agent_out": function_call}
125+
126+
127+
# We use the same forced final_answer LLM call to handle incorrectly formatted
# output from our query_agent.
def handle_error(state: "AgentState"):
    """Graph node: fallback — answer directly (without retrieved context) when
    the query agent's output could not be routed to a tool."""
    print("> handle_error")
    query = state["input"]
    prompt = f"""You are a helpful assistant, answer the user's question.

QUESTION: {query}
"""
    out = final_answer_llm.invoke(prompt)
    function_call = out.additional_kwargs["tool_calls"][-1]["function"]["arguments"]
    return {"agent_out": function_call}
139+
140+
141+
@with_langtrace_root_span("run_graph")
def run_graph():
    """Build the agent graph, compile it, and run it on a fixed question.

    Returns the final graph state produced by ``invoke`` (traced under a
    Langtrace root span named "run_graph").
    """
    graph = StateGraph(AgentState)

    # Four nodes, each consuming the shared agent state and returning a
    # partial update to it.
    graph.add_node("query_agent", run_query_agent)
    graph.add_node("search", execute_search)
    graph.add_node("error", handle_error)
    graph.add_node("rag_final_answer", rag_final_answer)

    # Execution always begins with the query agent.
    graph.set_entry_point("query_agent")

    # The router inspects the agent output and picks the next node; a direct
    # final_answer from the agent ends the run immediately.
    graph.add_conditional_edges(
        "query_agent",
        router,
        {
            "search": "search",
            "error": "error",
            "final_answer": END,
        },
    )
    graph.add_edge("search", "rag_final_answer")
    graph.add_edge("error", END)
    graph.add_edge("rag_final_answer", END)

    runnable = graph.compile()

    return runnable.invoke({"input": "what are EHI embeddings?"})
170+
171+
172+
# Script entry point: run the example graph once when executed directly.
if __name__ == "__main__":
    run_graph()

src/langtrace_python_sdk/instrumentation/crewai/patch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def _parse_tasks(self, tasks):
223223
for task in tasks:
224224
self.crew["tasks"].append(
225225
{
226-
"agent": task.agent.role,
226+
"agent": task.agent.role if task.agent else None,
227227
"description": task.description,
228228
"async_execution": task.async_execution,
229229
"expected_output": task.expected_output,

src/langtrace_python_sdk/instrumentation/dspy/instrumentation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@ class DspyInstrumentation(BaseInstrumentor):
2727
The DspyInstrumentor class represents the DSPy instrumentation"""
2828

2929
def instrumentation_dependencies(self) -> Collection[str]:
30-
return ["dspy-ai >= 2.0.0"]
30+
return ["dspy >= 2.0.0"]
3131

3232
def _instrument(self, **kwargs):
3333
tracer_provider = kwargs.get("tracer_provider")
3434
tracer = get_tracer(__name__, "", tracer_provider)
35-
version = v("dspy-ai")
35+
version = v("dspy")
3636
_W(
3737
"dspy.teleprompt.bootstrap",
3838
"BootstrapFewShot.compile",

src/langtrace_python_sdk/instrumentation/langgraph/instrumentation.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -41,34 +41,29 @@ def _instrument(self, **kwargs):
4141
# List of modules to patch, with their corresponding patch names
4242
modules_to_patch = [
4343
(
44-
"langgraph.graph.graph",
44+
"langgraph.graph.state", # Updated module path
45+
"StateGraph", # Updated class name
4546
[
4647
"add_node",
4748
"add_edge",
4849
"set_entry_point",
4950
"set_finish_point",
5051
"add_conditional_edges",
5152
],
52-
),
53+
)
5354
]
5455

55-
for module_name, methods in modules_to_patch:
56-
module = importlib.import_module(module_name)
57-
for name, obj in inspect.getmembers(
58-
module,
59-
lambda member: inspect.isclass(member)
60-
and member.__module__ == module.__name__,
61-
):
62-
for method_name, _ in inspect.getmembers(
63-
obj, predicate=inspect.isfunction
64-
):
65-
if method_name in methods:
66-
module = f"{name}.{method_name}"
67-
wrap_function_wrapper(
68-
module_name,
69-
module,
70-
patch_graph_methods(module, tracer, version),
71-
)
56+
for module_name, class_name, methods in modules_to_patch:
57+
for method_name in methods:
58+
# Construct the correct path for the method
59+
method_path = f"{class_name}.{method_name}"
60+
wrap_function_wrapper(
61+
module_name,
62+
method_path,
63+
patch_graph_methods(
64+
f"{module_name}.{method_path}", tracer, version
65+
),
66+
)
7267

7368
def _uninstrument(self, **kwargs):
7469
pass

src/langtrace_python_sdk/instrumentation/langgraph/patch.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from importlib_metadata import version as v
3131

3232
from langtrace_python_sdk.constants import LANGTRACE_SDK_NAME
33+
from langtrace_python_sdk.utils.llm import set_span_attributes
3334

3435

3536
def patch_graph_methods(method_name, tracer, version):
@@ -57,9 +58,7 @@ def traced_method(wrapped, instance, args, kwargs):
5758
kind=SpanKind.CLIENT,
5859
context=set_span_in_context(trace.get_current_span()),
5960
) as span:
60-
for field, value in attributes.model_dump(by_alias=True).items():
61-
if value is not None:
62-
span.set_attribute(field, value)
61+
set_span_attributes(span, attributes)
6362
try:
6463
# Attempt to call the original method
6564
result = wrapped(*args, **kwargs)

src/langtrace_python_sdk/langtrace.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ def init(
275275
"weaviate-client": WeaviateInstrumentation(),
276276
"sqlalchemy": SQLAlchemyInstrumentor(),
277277
"ollama": OllamaInstrumentor(),
278-
"dspy-ai": DspyInstrumentation(),
278+
"dspy": DspyInstrumentation(),
279279
"crewai": CrewAIInstrumentation(),
280280
"vertexai": VertexAIInstrumentation(),
281281
"google-cloud-aiplatform": VertexAIInstrumentation(),
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "3.3.11"
1+
__version__ = "3.3.13"

0 commit comments

Comments
 (0)