354 | 354 |   ],
355 | 355 |   "source": [
356 | 356 |   "# we drop sparse_values as they are not needed for this example\n",
357 |     | - "dataset.documents.drop(['sparse_values', 'blob'], axis=1, inplace=True)\n",
    | 357 | + "dataset.documents.drop([\"sparse_values\", \"blob\"], axis=1, inplace=True)\n",
358 | 358 |   "\n",
359 | 359 |   "dataset.head()"
360 | 360 |   ]

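Note: inplace=True mutates the frame in place and returns None; pandas style guidance generally prefers reassignment. A minimal sketch of the equivalent, assuming dataset.documents is a plain, reassignable pandas DataFrame attribute (not verified for this dataset class):

    # Sketch: drop the columns by reassignment instead of inplace=True
    dataset.documents = dataset.documents.drop(["sparse_values", "blob"], axis=1)
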
369 | 369 |   "\n",
370 | 370 |   "print(\"Here are some example topics in our Knowledge Base:\\n\")\n",
371 | 371 |   "for r in dataset.documents.iloc[:].to_dict(orient=\"records\"):\n",
372 |     | - "    topics.add(r['metadata']['title'])\n",
    | 372 | + "    topics.add(r[\"metadata\"][\"title\"])\n",
373 | 373 |   "\n",
374 | 374 |   "for topic in sorted(topics)[50:75]:\n",
375 | 375 |   "    print(f\"- {topic}\")"

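The loop collects unique titles into the set topics (initialized earlier in the cell, outside this hunk). A set-comprehension sketch of the same step, assuming the same record layout:

    # Sketch: build the topic set in one expression; iloc[:] is redundant here
    topics = {r["metadata"]["title"] for r in dataset.documents.to_dict(orient="records")}
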
396 | 396 |   "\n",
397 | 397 |   "if not os.environ.get(\"PINECONE_API_KEY\"):\n",
398 | 398 |   "    from pinecone_notebooks.colab import Authenticate\n",
    | 399 | + "\n",
399 | 400 |   "    Authenticate()"
400 | 401 |   ]
401 | 402 |   },

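The Authenticate() helper is Colab-specific. Outside Colab, a sketch of supplying the key by hand (the environment-variable name matches the one the cell checks):

    import os
    from getpass import getpass

    # Prompt for the key instead of using the Colab authenticator
    os.environ["PINECONE_API_KEY"] = getpass("Pinecone API key: ")
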
464 | 465 |   "source": [
465 | 466 |   "from pinecone import ServerlessSpec\n",
466 | 467 |   "\n",
467 |     | - "index_name = 'langchain-retrieval-agent-fast'\n",
    | 468 | + "index_name = \"langchain-retrieval-agent-fast\"\n",
468 | 469 |   "\n",
469 | 470 |   "if not pc.has_index(name=index_name):\n",
470 | 471 |   "    # Create a new index\n",
471 | 472 |   "    pc.create_index(\n",
472 | 473 |   "        name=index_name,\n",
473 | 474 |   "        dimension=1536,  # dimensionality of text-embedding-ada-002\n",
474 |     | - "        metric='dotproduct',\n",
475 |     | - "        spec=ServerlessSpec(\n",
476 |     | - "            cloud='aws',\n",
477 |     | - "            region='us-east-1'\n",
478 |     | - "        )\n",
    | 475 | + "        metric=\"dotproduct\",\n",
    | 476 | + "        spec=ServerlessSpec(cloud=\"aws\", region=\"us-east-1\"),\n",
479 | 477 |   "    )\n",
480 | 478 |   "\n",
481 | 479 |   "pc.describe_index(name=index_name)"

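Serverless index creation is asynchronous, so a freshly created index may not accept upserts immediately. A short readiness wait, following the pattern from the Pinecone Python SDK docs:

    import time

    # Block until the new index reports ready
    while not pc.describe_index(name=index_name).status["ready"]:
        time.sleep(1)
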
651 | 649 |   "source": [
652 | 650 |   "from langchain_openai import OpenAIEmbeddings\n",
653 | 651 |   "\n",
654 |     | - "openai_api_key = os.environ.get('OPENAI_API_KEY') or 'OPENAI_API_KEY'\n",
    | 652 | + "openai_api_key = os.environ.get(\"OPENAI_API_KEY\") or \"OPENAI_API_KEY\"\n",
655 | 653 |   "\n",
656 |     | - "embed = OpenAIEmbeddings(\n",
657 |     | - "    model='text-embedding-ada-002',\n",
658 |     | - "    openai_api_key=openai_api_key\n",
659 |     | - ")"
    | 654 | + "embed = OpenAIEmbeddings(model=\"text-embedding-ada-002\", openai_api_key=openai_api_key)"
660 | 655 |   ]
661 | 656 |   },
662 | 657 |   {

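The or "OPENAI_API_KEY" fallback is a placeholder literal meant to be replaced with a real key, not a working default. A quick sanity check that the embedder matches the index (assumes a valid key is set):

    # text-embedding-ada-002 vectors must match the 1536-dim index above
    vec = embed.embed_query("hello world")
    assert len(vec) == 1536
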
670 | 665 |   "from langchain_pinecone import PineconeVectorStore\n",
671 | 666 |   "\n",
672 | 667 |   "pinecone_vectorstore = PineconeVectorStore(\n",
673 |     | - "    index_name=index_name, \n",
674 |     | - "    embedding=embed, \n",
675 |     | - "    text_key=\"text\"\n",
    | 668 | + "    index_name=index_name, embedding=embed, text_key=\"text\"\n",
676 | 669 |   ")"
677 | 670 |   ]
678 | 671 |   },

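text_key names the metadata field that holds the raw text each vector was built from. If the index were empty, records could be added through the same wrapper; an illustrative sketch (the notebook upserts the dataset elsewhere):

    pinecone_vectorstore.add_texts(
        ["Example passage"], metadatas=[{"title": "Example"}]
    )
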
759 | 752 |   "source": [
760 | 753 |   "from pprint import pprint\n",
761 | 754 |   "\n",
762 |     | - "query = \"When was the college of engineering in the University of Notre Dame established?\"\n",
    | 755 | + "query = (\n",
    | 756 | + "    \"When was the college of engineering in the University of Notre Dame established?\"\n",
    | 757 | + ")\n",
763 | 758 |   "\n",
764 | 759 |   "documents = pinecone_vectorstore.similarity_search(\n",
765 |     | - "    query=query,\n",
766 |     | - "    k=3  # return 3 most relevant docs\n",
    | 760 | + "    query=query, k=3  # return 3 most relevant docs\n",
767 | 761 |   ")\n",
768 | 762 |   "\n",
769 | 763 |   "for doc in documents:\n",

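To see how relevant each hit actually is, the score-returning variant of the same call can be used; a sketch:

    # Each result comes back as a (Document, score) pair
    for doc, score in pinecone_vectorstore.similarity_search_with_score(query=query, k=3):
        print(f"{score:.3f}  {doc.page_content[:80]}")
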
815 | 809 |   "\n",
816 | 810 |   "# Chat completion LLM\n",
817 | 811 |   "llm = ChatOpenAI(\n",
818 |     | - "    openai_api_key=openai_api_key,\n",
819 |     | - "    model_name='gpt-3.5-turbo',\n",
820 |     | - "    temperature=0.0\n",
    | 812 | + "    openai_api_key=openai_api_key, model_name=\"gpt-3.5-turbo\", temperature=0.0\n",
821 | 813 |   ")"
822 | 814 |   ]
823 | 815 |   },

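temperature=0.0 keeps the model's answers and tool choices deterministic, which suits retrieval QA. The model can be smoke-tested on its own before wiring it into a chain:

    # Returns an AIMessage; .content holds the text
    print(llm.invoke("Reply with the single word: ready").content)
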
839 | 831 |   "from langchain_core.runnables import RunnablePassthrough\n",
840 | 832 |   "\n",
841 | 833 |   "# Based on the RAG template from https://smith.langchain.com/hub/rlm/rag-prompt\n",
842 |     | - "template=(\n",
    | 834 | + "template = (\n",
843 | 835 |   "    \"You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.\"\n",
844 | 836 |   "    \"Question: {question}\"\n",
845 | 837 |   "    \"Context: {context}\"\n",
846 | 838 |   "    \"Answer:\"\n",
847 | 839 |   ")\n",
848 | 840 |   "prompt = PromptTemplate(input_variables=[\"question\", \"context\"], template=template)\n",
849 | 841 |   "\n",
    | 842 | + "\n",
850 | 843 |   "def format_docs(docs):\n",
851 | 844 |   "    return \"\\n\\n\".join(doc.page_content for doc in docs)\n",
852 | 845 |   "\n",
    | 846 | + "\n",
853 | 847 |   "# Retrieval Question-Answer chain\n",
854 | 848 |   "qa_chain = (\n",
855 | 849 |   "    {\n",

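One caveat the Black-style reformat does not address: adjacent string literals concatenate with no separator, so the rendered prompt reads "...concise.Question: ...". A sketch with explicit newlines (same variables; the instruction sentence is abbreviated here):

    template = (
        "You are an assistant for question-answering tasks. ...\n"  # full sentence as above
        "Question: {question}\n"
        "Context: {context}\n"
        "Answer:"
    )
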
896 | 890 |   }
897 | 891 |   ],
898 | 892 |   "source": [
899 |     | - "qa_chain.invoke(\"When was the college of engineering in the University of Notre Dame established?\")"
    | 893 | + "qa_chain.invoke(\n",
    | 894 | + "    \"When was the college of engineering in the University of Notre Dame established?\"\n",
    | 895 | + ")"
900 | 896 |   ]
901 | 897 |   },
902 | 898 |   {

920 | 916 |   "outputs": [],
921 | 917 |   "source": [
922 | 918 |   "knowledge_base_tool = qa_chain.as_tool(\n",
923 |     | - "    name='knowledge-base',\n",
924 |     | - "    description=(\n",
925 |     | - "        'use this tool when answering general knowledge queries to get '\n",
926 |     | - "        'more information about the topic'\n",
927 |     | - "    )\n",
    | 919 | + "    name=\"knowledge-base\",\n",
    | 920 | + "    description=(\n",
    | 921 | + "        \"use this tool when answering general knowledge queries to get \"\n",
    | 922 | + "        \"more information about the topic\"\n",
    | 923 | + "    ),\n",
928 | 924 |   ")"
929 | 925 |   ]
930 | 926 |   },

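The name and description are what the tool-calling model reads when deciding whether to invoke the tool, so they matter more than they look. The tool can also be exercised directly; a sketch, assuming as_tool on this string-input runnable exposes a single string argument:

    # Illustrative call, bypassing the agent
    knowledge_base_tool.invoke("Who founded the University of Notre Dame?")
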
966 | 962 |   "from langgraph.graph import StateGraph\n",
967 | 963 |   "from langgraph.graph.message import add_messages\n",
968 | 964 |   "\n",
    | 965 | + "\n",
969 | 966 |   "class State(TypedDict):\n",
970 | 967 |   "    messages: Annotated[list, add_messages]\n",
971 | 968 |   "\n",
    | 969 | + "\n",
972 | 970 |   "graph_builder = StateGraph(State)"
973 | 971 |   ]
974 | 972 |   },

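add_messages is a reducer: when a node returns {"messages": [...]}, LangGraph appends those messages to the running list instead of overwriting it. Conceptually:

    # What the reducer does with each node's update (sketch)
    state["messages"] = add_messages(state["messages"], node_update["messages"])
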
1001 |  999 |   "tools = [knowledge_base_tool]\n",
1002 | 1000 |   "llm_with_tools = llm.bind_tools(tools)\n",
1003 | 1001 |   "\n",
     | 1002 | + "\n",
1004 | 1003 |   "def chatbot(state: State):\n",
1005 | 1004 |   "    return {\"messages\": [llm_with_tools.invoke(state[\"messages\"])]}\n",
1006 | 1005 |   "\n",
     | 1006 | + "\n",
1007 | 1007 |   "graph_builder.add_node(\"chatbot\", chatbot)\n",
1008 | 1008 |   "\n",
1009 | 1009 |   "tool_node = ToolNode(tools=tools)\n",

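The hunk cuts off after the ToolNode is built. The usual wiring that follows, per the standard LangGraph tutorial pattern (a sketch, not shown in this diff):

    from langgraph.prebuilt import tools_condition

    # Route tool calls from the chatbot to the ToolNode, then back
    graph_builder.add_node("tools", tool_node)
    graph_builder.add_conditional_edges("chatbot", tools_condition)
    graph_builder.add_edge("tools", "chatbot")
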
1054 | 1054 |   "source": [
1055 | 1055 |   "def agent(user_message):\n",
1056 | 1056 |   "    config = {\"configurable\": {\"thread_id\": \"1\"}}\n",
1057 |      | - "    \n",
     | 1057 | + "\n",
1058 | 1058 |   "    # The config is the **second positional argument** to stream() or invoke()!\n",
1059 | 1059 |   "    events = graph.stream(\n",
1060 | 1060 |   "        {\"messages\": [{\"role\": \"user\", \"content\": user_message}]},\n",

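The hunk is truncated mid-call; the completed pattern, hedged as a sketch of the standard LangGraph streaming loop:

    # config rides along as the second positional argument
    events = graph.stream(
        {"messages": [{"role": "user", "content": user_message}]},
        config,
        stream_mode="values",
    )
    for event in events:
        event["messages"][-1].pretty_print()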