from typing import TypedDict, Union, Annotated
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.tools import tool
import operator
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI

from langchain import hub
from langchain.agents import create_openai_tools_agent
import json
from langgraph.graph import StateGraph, END
from langtrace_python_sdk import langtrace, with_langtrace_root_span

load_dotenv()

langtrace.init(write_spans_to_console=False)


class AgentState(TypedDict):
    input: str
    agent_out: Union[AgentAction, AgentFinish, None]
    intermediate_steps: Annotated[list[tuple[AgentAction, str]], operator.add]
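
# `agent_out` stores the most recent agent decision (tool call(s) or a finish),
# while `intermediate_steps` is annotated with `operator.add`, so LangGraph
# merges each node's partial update by list concatenation instead of
# overwriting the previous value. A minimal sketch of that reducer behaviour
# (illustrative only, not used elsewhere in this script):
#
#   operator.add([("step_1", "out_1")], [("step_2", "out_2")])
#   # -> [("step_1", "out_1"), ("step_2", "out_2")]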


ehi_information = """Title: EHI: End-to-end Learning of Hierarchical Index for
Efficient Dense Retrieval
Summary: Dense embedding-based retrieval is now the industry
standard for semantic search and ranking problems, like obtaining relevant web
documents for a given query. Such techniques use a two-stage process: (a)
contrastive learning to train a dual encoder to embed both the query and
documents and (b) approximate nearest neighbor search (ANNS) for finding similar
documents for a given query. These two stages are disjoint; the learned
embeddings might be ill-suited for the ANNS method and vice-versa, leading to
suboptimal performance. In this work, we propose End-to-end Hierarchical
Indexing -- EHI -- that jointly learns both the embeddings and the ANNS
structure to optimize retrieval performance. EHI uses a standard dual encoder
model for embedding queries and documents while learning an inverted file index
(IVF) style tree structure for efficient ANNS. To ensure stable and efficient
learning of discrete tree-based ANNS structure, EHI introduces the notion of
dense path embedding that captures the position of a query/document in the tree.
We demonstrate the effectiveness of EHI on several benchmarks, including
de-facto industry standard MS MARCO (Dev set and TREC DL19) datasets. For
example, with the same compute budget, EHI outperforms state-of-the-art (SOTA)
in by 0.6% (MRR@10) on MS MARCO dev set and by 4.2% (nDCG@10) on TREC DL19
benchmarks.
Author(s): Ramnath Kumar, Anshul Mittal, Nilesh Gupta, Aditya Kusupati,
Inderjit Dhillon, Prateek Jain
Source: https://arxiv.org/pdf/2310.08891.pdf"""


@tool("search")
def search_tool(query: str):
    """Searches for information on the topic of artificial intelligence (AI).
    Cannot be used to research any other topics. Search query must be provided
    in natural language and be verbose."""
    # this is a "RAG" emulator
    return ehi_information


@tool("final_answer")
def final_answer_tool(answer: str, source: str):
    """Returns a natural language response to the user in `answer`, and a
    `source` which provides citations for where this information came from.
    """
    # the return value is never used; this tool only defines the schema that the
    # forced final_answer call fills in
    return ""


llm = ChatOpenAI()
prompt = hub.pull("hwchase17/openai-functions-agent")


query_agent_runnable = create_openai_tools_agent(
    llm=llm, tools=[final_answer_tool, search_tool], prompt=prompt
)


# quick standalone run of the query agent outside the graph, as a sanity check
inputs = {"input": "what are EHI embeddings?", "intermediate_steps": []}

agent_out = query_agent_runnable.invoke(inputs)
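# At this point `agent_out` is normally a list of tool-call actions produced by
# the OpenAI tools agent, each carrying the raw `tool_calls` payload that
# `execute_search` unpacks below. The print below is an illustrative addition
# for inspecting that structure, not part of the original script.
print(f"standalone agent_out type: {type(agent_out)}")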


def run_query_agent(state: AgentState):
    print("> run_query_agent")
    agent_out = query_agent_runnable.invoke(state)
    return {"agent_out": agent_out}


def execute_search(state: AgentState):
    print("> execute_search")
    action = state["agent_out"]
    # dig the latest tool call out of the agent's message log and pass its
    # arguments to the search tool
    tool_call = action[-1].message_log[-1].additional_kwargs["tool_calls"][-1]
    out = search_tool.invoke(json.loads(tool_call["function"]["arguments"]))
    return {"intermediate_steps": [{"search": str(out)}]}


def router(state: AgentState):
    print("> router")
    # route on the name of the tool the agent chose; anything unexpected falls
    # through to the error handler
    if isinstance(state["agent_out"], list):
        return state["agent_out"][-1].tool
    else:
        return "error"


# finally, we will have a single LLM call that MUST use the final_answer structure
final_answer_llm = llm.bind_tools([final_answer_tool], tool_choice="final_answer")


# this forced final_answer LLM call will be used to structure output from our
# RAG endpoint
def rag_final_answer(state: AgentState):
    print("> final_answer")
    query = state["input"]
    context = state["intermediate_steps"][-1]

    prompt = f"""You are a helpful assistant, answer the user's question using the
    context provided.

    CONTEXT: {context}

    QUESTION: {query}
    """
    out = final_answer_llm.invoke(prompt)
    # the forced tool call's arguments are a JSON string containing `answer` and `source`
    function_call = out.additional_kwargs["tool_calls"][-1]["function"]["arguments"]
    return {"agent_out": function_call}


# we use the same forced final_answer LLM call to handle incorrectly formatted
# output from our query_agent
def handle_error(state: AgentState):
    print("> handle_error")
    query = state["input"]
    prompt = f"""You are a helpful assistant, answer the user's question.

    QUESTION: {query}
    """
    out = final_answer_llm.invoke(prompt)
    function_call = out.additional_kwargs["tool_calls"][-1]["function"]["arguments"]
    return {"agent_out": function_call}


@with_langtrace_root_span("run_graph")
def run_graph():
    graph = StateGraph(AgentState)

    # we have four nodes that will consume our agent state and modify
    # our agent state based on some internal process
    graph.add_node("query_agent", run_query_agent)
    graph.add_node("search", execute_search)
    graph.add_node("error", handle_error)
    graph.add_node("rag_final_answer", rag_final_answer)
    # our graph will always begin with the query agent
    graph.set_entry_point("query_agent")
    # conditional edges are controlled by our router
    graph.add_conditional_edges(
        "query_agent",
        router,
        {
            "search": "search",
            "error": "error",
            "final_answer": END,
        },
    )
    graph.add_edge("search", "rag_final_answer")
    graph.add_edge("error", END)
    graph.add_edge("rag_final_answer", END)

    runnable = graph.compile()

    return runnable.invoke({"input": "what are EHI embeddings?"})


if __name__ == "__main__":
    result = run_graph()
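    # The returned state carries the arguments of the forced `final_answer` tool
    # call as a JSON string in `agent_out` (unless the query agent finished on
    # its own, in which case the raw agent output is returned). Decoding and
    # printing it here is an illustrative addition, not part of the original
    # graph logic.
    if isinstance(result["agent_out"], str):
        final = json.loads(result["agent_out"])
        print(final.get("answer"))
        print(final.get("source"))
    else:
        print(result["agent_out"])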