from typing import Annotated

from langchain_anthropic import ChatAnthropic
from langchain_core.messages import AIMessage, ToolMessage
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.tools import Tool
from langgraph.checkpoint.memory import MemorySaver
from langgraph.graph import StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from typing_extensions import TypedDict

from langtrace_python_sdk import langtrace

langtrace.init()

# Demo lookup table mapping `n` to the `n`th prime number.
primes = {998: 7901, 999: 7907, 1000: 7919}


class PrimeInput(BaseModel):
    n: int = Field()


def is_prime(n: int) -> bool:
    """Check primality by trial division (helper for validating demo data)."""
    if n <= 1 or (n % 2 == 0 and n > 2):
        return False
    for i in range(3, int(n**0.5) + 1, 2):
        if n % i == 0:
            return False
    return True


def get_prime(n: int, primes: dict = primes) -> str:
    return str(primes.get(int(n)))


async def aget_prime(n: int, primes: dict = primes) -> str:
    return str(primes.get(int(n)))


class State(TypedDict):
    messages: Annotated[list, add_messages]
    # Flag telling the router to escalate to a human.
    ask_human: bool


class RequestAssistance(BaseModel):
    """Escalate the conversation to an expert. Use this if you are unable to assist directly or if the user requires support beyond your permissions.

    To use this function, relay the user's 'request' so the expert can provide the right guidance.
    """

    request: str


tools = [
    Tool(
        name="GetPrime",
        func=get_prime,
        description="A tool that returns the `n`th prime number",
        args_schema=PrimeInput,
        coroutine=aget_prime,
    ),
]

llm = ChatAnthropic(model="claude-3-haiku-20240307")
# We can bind the LLM to a tool definition, a Pydantic model, or a JSON schema.
# Binding GetPrime alongside RequestAssistance lets the model call either;
# without it, the "tools" node below would be unreachable.
llm_with_tools = llm.bind_tools(tools + [RequestAssistance])
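
# Illustrative sanity check (not part of the graph): the demo table's entry
# for n=999 should be returned as a string and should itself be prime.
assert get_prime(999) == "7907" and is_prime(7907)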


def chatbot(state: State):
    response = llm_with_tools.invoke(state["messages"])
    ask_human = False
    if (
        response.tool_calls
        and response.tool_calls[0]["name"] == RequestAssistance.__name__
    ):
        ask_human = True
    return {"messages": [response], "ask_human": ask_human}


graph_builder = StateGraph(State)

graph_builder.add_node("chatbot", chatbot)
graph_builder.add_node("tools", ToolNode(tools=tools))


def create_response(response: str, ai_message: AIMessage):
    return ToolMessage(
        content=response,
        tool_call_id=ai_message.tool_calls[0]["id"],
    )


def human_node(state: State):
    new_messages = []
    if not isinstance(state["messages"][-1], ToolMessage):
        # Typically, the user will have updated the state during the interrupt.
        # If they choose not to, we include a placeholder ToolMessage so the
        # LLM can continue.
        new_messages.append(
            create_response("No response from human.", state["messages"][-1])
        )
    return {
        # Append the new messages
        "messages": new_messages,
        # Unset the flag
        "ask_human": False,
    }


def select_next_node(state: State):
    if state["ask_human"]:
        return "human"
    # Otherwise, route on tool calls as usual.
    return tools_condition(state)


def basic_graph_tools():
    graph_builder.add_node("human", human_node)
    graph_builder.add_conditional_edges(
        "chatbot",
        select_next_node,
        {"human": "human", "tools": "tools", "__end__": "__end__"},
    )
    graph_builder.add_edge("tools", "chatbot")
    graph_builder.add_edge("human", "chatbot")
    graph_builder.set_entry_point("chatbot")
    memory = MemorySaver()
    graph = graph_builder.compile(
        checkpointer=memory,
        # Pause execution before the "human" node so a person can intervene.
        interrupt_before=["human"],
    )

    config = {"configurable": {"thread_id": "1"}}
    events = graph.stream(
        {
            "messages": [
                (
                    "user",
                    "I'm learning LangGraph. Could you do some research on it for me?",
                )
            ]
        },
        config,
        stream_mode="values",
    )
    for event in events:
        if "messages" in event:
            event["messages"][-1].pretty_print()