|
70 | 70 | }, |
71 | 71 | { |
72 | 72 | "cell_type": "code", |
73 | | - "execution_count": null, |
| 73 | + "execution_count": 12, |
74 | 74 | "metadata": {}, |
75 | | - "outputs": [], |
| 75 | + "outputs": [ |
| 76 | + { |
| 77 | + "name": "stdout", |
| 78 | + "output_type": "stream", |
| 79 | + "text": [ |
| 80 | + "Austin airport has 98 outgoing routes.\n" |
| 81 | + ] |
| 82 | + } |
| 83 | + ], |
76 | 84 | "source": [ |
77 | 85 | "from langchain_aws import ChatBedrockConverse\n", |
78 | 86 | "from langchain_aws.chains import create_neptune_opencypher_qa_chain\n", |
|
83 | 91 | " temperature=0,\n", |
84 | 92 | ")\n", |
85 | 93 | "\n", |
86 | | - "chain = create_neptune_opencypher_qa_chain(\n", |
87 | | - " llm=llm,\n", |
88 | | - " graph=graph,\n", |
89 | | - ")\n", |
| 94 | + "chain = create_neptune_opencypher_qa_chain(llm=llm, graph=graph)\n", |
| 95 | + "\n", |
| 96 | + "result = chain.invoke(\"How many outgoing routes does the Austin airport have?\")\n", |
| 97 | + "print(result[\"result\"].content)" |
| 98 | + ] |
| 99 | + }, |
| 100 | + { |
| 101 | + "cell_type": "markdown", |
| 102 | + "metadata": {}, |
| 103 | + "source": [ |
| 104 | + "### Adding Message History\n", |
| 105 | + "\n", |
| 106 | + "The Neptune openCypher QA chain has the ability to be wrapped by [`RunnableWithMessageHistory`](https://python.langchain.com/v0.2/api_reference/core/runnables/langchain_core.runnables.history.RunnableWithMessageHistory.html#langchain_core.runnables.history.RunnableWithMessageHistory). This adds message history to the chain, allowing us to create a chatbot that retains conversation state across multiple invocations.\n", |
| 107 | + "\n", |
| 108 | + "To start, we need a way to store and load the message history. For this purpose, each thread will be created as an instance of [`InMemoryChatMessageHistory`](https://python.langchain.com/api_reference/core/chat_history/langchain_core.chat_history.InMemoryChatMessageHistory.html), and stored into a dictionary for repeated access.\n", |
| 109 | + "\n", |
| 110 | + "(Also see: https://python.langchain.com/docs/versions/migrating_memory/chat_history/#chatmessagehistory)" |
| 111 | + ] |
| 112 | + }, |
| 113 | + { |
| 114 | + "cell_type": "code", |
| 115 | + "execution_count": null, |
| 116 | + "metadata": {}, |
| 117 | + "outputs": [], |
| 118 | + "source": [ |
| 119 | + "from langchain_core.chat_history import InMemoryChatMessageHistory\n", |
| 120 | + "\n", |
| 121 | + "chats_by_session_id = {}\n", |
| 122 | + "\n", |
| 123 | + "\n", |
| 124 | + "def get_chat_history(session_id: str) -> InMemoryChatMessageHistory:\n", |
| 125 | + " chat_history = chats_by_session_id.get(session_id)\n", |
| 126 | + " if chat_history is None:\n", |
| 127 | + " chat_history = InMemoryChatMessageHistory()\n", |
| 128 | + " chats_by_session_id[session_id] = chat_history\n", |
| 129 | + " return chat_history" |
| 130 | + ] |
| 131 | + }, |
| 132 | + { |
| 133 | + "cell_type": "markdown", |
| 134 | + "metadata": {}, |
| 135 | + "source": [ |
| 136 | + "Now, the QA chain and message history storage can be used to create the new `RunnableWithMessageHistory`. Note that we must set `query` as the input key to match the format expected by the base chain." |
| 137 | + ] |
| 138 | + }, |
| 139 | + { |
| 140 | + "cell_type": "code", |
| 141 | + "execution_count": null, |
| 142 | + "metadata": {}, |
| 143 | + "outputs": [], |
| 144 | + "source": [ |
| 145 | + "from langchain_core.runnables.history import RunnableWithMessageHistory\n", |
| 146 | + "\n", |
| 147 | + "runnable_with_history = RunnableWithMessageHistory(\n", |
| 148 | + " chain,\n", |
| 149 | + " get_chat_history,\n", |
| 150 | + " input_messages_key=\"query\",\n", |
| 151 | + ")" |
| 152 | + ] |
| 153 | + }, |
| 154 | + { |
| 155 | + "cell_type": "markdown", |
| 156 | + "metadata": {}, |
| 157 | + "source": [ |
| 158 | + "Before invoking the chain, a unique `session_id` needs to be generated for the conversation that the new `InMemoryChatMessageHistory` will remember." |
| 159 | + ] |
| 160 | + }, |
| 161 | + { |
| 162 | + "cell_type": "code", |
| 163 | + "execution_count": null, |
| 164 | + "metadata": {}, |
| 165 | + "outputs": [], |
| 166 | + "source": [ |
| 167 | + "import uuid\n", |
90 | 168 | "\n", |
91 | | - "result = chain.invoke(\n", |
92 | | - " {\"query\": \"How many outgoing routes does the Austin airport have?\"}\n", |
| 169 | + "session_id = uuid.uuid4()" |
| 170 | + ] |
| 171 | + }, |
| 172 | + { |
| 173 | + "cell_type": "markdown", |
| 174 | + "metadata": {}, |
| 175 | + "source": [ |
| 176 | + "Finally, invoke the message history enabled chain with the `session_id`." |
| 177 | + ] |
| 178 | + }, |
| 179 | + { |
| 180 | + "cell_type": "code", |
| 181 | + "execution_count": 8, |
| 182 | + "metadata": {}, |
| 183 | + "outputs": [ |
| 184 | + { |
| 185 | + "name": "stdout", |
| 186 | + "output_type": "stream", |
| 187 | + "text": [ |
| 188 | + "You can fly directly to 98 destinations from Austin airport.\n" |
| 189 | + ] |
| 190 | + } |
| 191 | + ], |
| 192 | + "source": [ |
| 193 | + "result = runnable_with_history.invoke(\n", |
| 194 | + " {\"query\": \"How many destinations can I fly to directly from Austin airport?\"},\n", |
| 195 | + " config={\"configurable\": {\"session_id\": session_id}},\n", |
| 196 | + ")\n", |
| 197 | + "print(result[\"result\"].content)" |
| 198 | + ] |
| 199 | + }, |
| 200 | + { |
| 201 | + "cell_type": "markdown", |
| 202 | + "metadata": {}, |
| 203 | + "source": [ |
| 204 | + "As the chain continues to be invoked with the same `session_id`, responses will be returned in the context of previous queries in the conversation.\n" |
| 205 | + ] |
| 206 | + }, |
| 207 | + { |
| 208 | + "cell_type": "code", |
| 209 | + "execution_count": 9, |
| 210 | + "metadata": {}, |
| 211 | + "outputs": [ |
| 212 | + { |
| 213 | + "name": "stdout", |
| 214 | + "output_type": "stream", |
| 215 | + "text": [ |
| 216 | + "You can fly directly to 4 destinations in Europe from Austin airport.\n" |
| 217 | + ] |
| 218 | + } |
| 219 | + ], |
| 220 | + "source": [ |
| 221 | + "result = runnable_with_history.invoke(\n", |
| 222 | + " {\"query\": \"Out of those destinations, how many are in Europe?\"},\n", |
| 223 | + " config={\"configurable\": {\"session_id\": session_id}},\n", |
| 224 | + ")\n", |
| 225 | + "print(result[\"result\"].content)" |
| 226 | + ] |
| 227 | + }, |
| 228 | + { |
| 229 | + "cell_type": "code", |
| 230 | + "execution_count": 10, |
| 231 | + "metadata": {}, |
| 232 | + "outputs": [ |
| 233 | + { |
| 234 | + "name": "stdout", |
| 235 | + "output_type": "stream", |
| 236 | + "text": [ |
| 237 | + "The four European destinations you can fly to directly from Austin airport are:\n", |
| 238 | + "- AMS (Amsterdam Airport Schiphol)\n", |
| 239 | + "- FRA (Frankfurt am Main)\n", |
| 240 | + "- LGW (London Gatwick)\n", |
| 241 | + "- LHR (London Heathrow)\n" |
| 242 | + ] |
| 243 | + } |
| 244 | + ], |
| 245 | + "source": [ |
| 246 | + "result = runnable_with_history.invoke(\n", |
| 247 | + " {\"query\": \"Give me the codes and names of those airports.\"},\n", |
| 248 | + " config={\"configurable\": {\"session_id\": session_id}},\n", |
93 | 249 | ")\n", |
94 | 250 | "print(result[\"result\"].content)" |
95 | 251 | ] |
96 | 252 | } |
97 | 253 | ], |
98 | 254 | "metadata": { |
99 | 255 | "kernelspec": { |
100 | | - "display_name": "Python 3 (ipykernel)", |
| 256 | + "display_name": "Python 3", |
101 | 257 | "language": "python", |
102 | 258 | "name": "python3" |
103 | 259 | }, |
|
111 | 267 | "name": "python", |
112 | 268 | "nbconvert_exporter": "python", |
113 | 269 | "pygments_lexer": "ipython3", |
114 | | - "version": "3.10.12" |
| 270 | + "version": "3.10.13" |
115 | 271 | } |
116 | 272 | }, |
117 | 273 | "nbformat": 4, |
|
0 commit comments