From 7a6b7877ce9ed8e7435d78e9b711b793858e133b Mon Sep 17 00:00:00 2001 From: subhi Date: Thu, 9 Oct 2025 02:58:09 +0400 Subject: [PATCH] Add optional graph visualization cells to Graph-RAG quickstart notebook --- .../quickstart/graph_rag_with_milvus.ipynb | 326 ++++++++++++++++++ 1 file changed, 326 insertions(+) diff --git a/tutorials/quickstart/graph_rag_with_milvus.ipynb b/tutorials/quickstart/graph_rag_with_milvus.ipynb index 63b8e2da6..c1f1db448 100644 --- a/tutorials/quickstart/graph_rag_with_milvus.ipynb +++ b/tutorials/quickstart/graph_rag_with_milvus.ipynb @@ -327,6 +327,173 @@ " relationid_2_passageids[relations.index(relation)].append(passage_id)" ] }, + { + "cell_type": "markdown", + "metadata": { + "id": "FUAm4-guPnzT" + }, + "source": [ + "**Optional**: run the following three cells to visualize graph knowledge for the nano_dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TAAHENTxfJLO" + }, + "source": [ + "> If you are using Google Colab, skip the following pip install as matplotlib and networkx is built in and textwrap is standard Python library." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "08kEDCmIecdg" + }, + "outputs": [], + "source": [ + "! pip install matplotlib networkx" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "id": "C6Nm4UTVGquY" + }, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import networkx as nx\n", + "import textwrap" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 758 + }, + "id": "0DiR9RDdDmGQ", + "outputId": "3ebca80a-b0b5-4ee5-a645-ab330f85ae8a" + }, + "outputs": [], + "source": [ + "G_full = nx.Graph()\n", + "\n", + "entity_labels_full = {i: entities[i] for i in range(len(entities))}\n", + "relation_labels_full = {i: relations[i] for i in range(len(relations))}\n", + "\n", + "entity_nodes_full = []\n", + "relation_nodes_full = []\n", + "\n", + "for e in range(len(entities)):\n", + " n = f\"E:{entity_labels_full[e]}\"\n", + " G_full.add_node(n, kind=\"entity\", eid=e)\n", + " entity_nodes_full.append(n)\n", + "\n", + "for r in range(len(relations)):\n", + " n = f\"R:{relation_labels_full[r]}\"\n", + " G_full.add_node(n, kind=\"relation\", rid=r)\n", + " relation_nodes_full.append(n)\n", + "\n", + "for e, rel_ids in entityid_2_relationids.items():\n", + " e_name = f\"E:{entity_labels_full[e]}\"\n", + " for r in rel_ids:\n", + " r_name = f\"R:{relation_labels_full[r]}\"\n", + " G_full.add_edge(e_name, r_name)\n", + "\n", + "\n", + "def wrap_label(s, width=28):\n", + " return textwrap.fill(s, width=width)\n", + "\n", + "\n", + "display_labels_full = {}\n", + "for n in G_full.nodes:\n", + " if n.startswith(\"E:\"):\n", + " display_labels_full[n] = \"E: \" + wrap_label(n[2:], 28)\n", + " elif n.startswith(\"R:\"):\n", + " display_labels_full[n] = \"R: \" + wrap_label(n[2:], 28)\n", + " else:\n", + " display_labels_full[n] = wrap_label(n, 28)\n", + "\n", + "pos_full = {}\n", + "\n", + "left_x, right_x = -1.0, 1.0\n", + "available_height = 1.9\n", + "\n", + "\n", + "def side_positions(nodes):\n", + " if not nodes:\n", + " return {}\n", + " units = []\n", + " for n in nodes:\n", + " lines = display_labels_full[n].count(\"\\n\") + 1\n", + " units.append(lines + 0.6)\n", + " unit_h = available_height / max(sum(units), 1.0)\n", + " y = 0.95\n", + " y_map = {}\n", + " for n, u in zip(nodes, units):\n", + " y_map[n] = y\n", + " y -= u * unit_h\n", + " return y_map\n", + "\n", + "left_y_map = side_positions(entity_nodes_full)\n", + "right_y_map = side_positions(relation_nodes_full)\n", + "\n", + "for n in entity_nodes_full:\n", + " pos_full[n] = (left_x, left_y_map[n])\n", + "\n", + "for n in relation_nodes_full:\n", + " pos_full[n] = (right_x, right_y_map[n])\n", + "\n", + "label_pos_full = {\n", + " n: ((-1.02 if n.startswith(\"E:\") else 1.02), pos_full[n][1]) for n in G_full.nodes\n", + "}\n", + "\n", + "node_colors_full = []\n", + "node_sizes_full = []\n", + "for n, data in G_full.nodes(data=True):\n", + " if data.get(\"kind\") == \"entity\":\n", + " node_colors_full.append(\"#1f77b4\")\n", + " node_sizes_full.append(650)\n", + " else:\n", + " node_colors_full.append(\"#ffbb78\")\n", + " node_sizes_full.append(600)\n", + "\n", + "plt.figure(figsize=(20, 12))\n", + "ax = plt.gca()\n", + "nx.draw_networkx_nodes(\n", + " G_full, \n", + " pos_full, \n", + " node_color=node_colors_full, \n", + " node_size=node_sizes_full, \n", + " ax=ax\n", + ")\n", + "nx.draw_networkx_edges(\n", + " G_full, \n", + " pos_full,\n", + " edge_color=\"#cfcfcf\",\n", + " width=1.2,\n", + " ax=ax, \n", + " alpha=0.9\n", + ")\n", + "nx.draw_networkx_labels(\n", + " G_full, \n", + " label_pos_full, \n", + " labels=display_labels_full, \n", + " font_size=8, \n", + " ax=ax,\n", + " bbox=dict(boxstyle=\"round,pad=0.2\", fc=\"white\", ec=\"none\", alpha=0.9)\n", + ")\n", + "ax.set_axis_off()\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, { "cell_type": "markdown", "metadata": { @@ -632,6 +799,165 @@ "We have get the candidate relationships by expanding the subgraph, which will be reranked by LLM in the next step." ] }, + { + "cell_type": "markdown", + "metadata": { + "id": "B5DtaEx-bWuu" + }, + "source": [ + "**Optional**: run the following three cells to visualize candidates subgraph" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9VD1ZF8Ifwuy" + }, + "source": [ + "> If you are using Google Colab, skip the following pip install as matplotlib and networkx is built in and textwrap is standard Python library." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "PjnwcGIUblwb" + }, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import networkx as nx\n", + "import textwrap" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 758 + }, + "id": "qcucyXBTQMeA", + "outputId": "051f520d-e4c4-4124-f4a3-ef4347272333" + }, + "outputs": [], + "source": [ + "G_sub = nx.Graph()\n", + "\n", + "relation_candidate_id_set = set(relation_candidate_ids)\n", + "\n", + "entity_labels_sub = {i: entities[i] for i in range(len(entities))}\n", + "relation_labels_sub = {i: relations[i] for i in range(len(relations))}\n", + "\n", + "entity_nodes_sub = []\n", + "relation_nodes_sub = []\n", + "\n", + "for r in relation_candidate_ids:\n", + " n = f\"R:{relation_labels_sub[r]}\"\n", + " G_sub.add_node(n, kind=\"relation\", rid=r)\n", + " relation_nodes_sub.append(n)\n", + "\n", + "for e, rel_ids in entityid_2_relationids.items():\n", + "\n", + " sub_rels = [r for r in rel_ids if r in relation_candidate_id_set]\n", + " if not sub_rels:\n", + " continue\n", + " e_name = f\"E:{entity_labels_sub[e]}\"\n", + " if e_name not in G_sub:\n", + " G_sub.add_node(e_name, kind=\"entity\", eid=e)\n", + " entity_nodes_sub.append(e_name)\n", + " for r in sub_rels:\n", + " r_name = f\"R:{relation_labels_sub[r]}\"\n", + " G_sub.add_edge(e_name, r_name)\n", + "\n", + "\n", + "def wrap_label(s, width=28):\n", + " return textwrap.fill(s, width=width)\n", + "\n", + "\n", + "display_labels_sub = {}\n", + "for n in G_sub.nodes:\n", + " if n.startswith(\"E:\"):\n", + " display_labels_sub[n] = \"E: \" + wrap_label(n[2:], 28)\n", + " elif n.startswith(\"R:\"):\n", + " display_labels_sub[n] = \"R: \" + wrap_label(n[2:], 28)\n", + " else:\n", + " display_labels_sub[n] = wrap_label(n, 28)\n", + "\n", + "pos_sub = {}\n", + "\n", + "left_x, right_x = -1.0, 1.0\n", + "available_height = 1.9\n", + "\n", + "\n", + "def side_positions(nodes):\n", + " if not nodes:\n", + " return {}\n", + " units = []\n", + " for n in nodes:\n", + " lines = display_labels_sub[n].count(\"\\n\") + 1\n", + " units.append(lines + 0.6)\n", + " unit_h = available_height / max(sum(units), 1.0)\n", + " y = 0.95\n", + " y_map = {}\n", + " for n, u in zip(nodes, units):\n", + " y_map[n] = y\n", + " y -= u * unit_h\n", + " return y_map\n", + "\n", + "left_y_map = side_positions(entity_nodes_sub)\n", + "right_y_map = side_positions(relation_nodes_sub)\n", + "\n", + "for n in entity_nodes_sub:\n", + " pos_sub[n] = (left_x, left_y_map[n])\n", + "\n", + "for n in relation_nodes_sub:\n", + " pos_sub[n] = (right_x, right_y_map[n])\n", + "\n", + "label_pos_sub = {\n", + " n: ((-1.02 if n.startswith(\"E:\") else 1.02), pos_sub[n][1]) for n in G_sub.nodes\n", + "}\n", + "\n", + "node_colors_sub = []\n", + "node_sizes_sub = []\n", + "for n, data in G_sub.nodes(data=True):\n", + " if data.get(\"kind\") == \"entity\":\n", + " node_colors_sub.append(\"#1f77b4\")\n", + " node_sizes_sub.append(650)\n", + " else:\n", + " node_colors_sub.append(\"#ffbb78\")\n", + " node_sizes_sub.append(600)\n", + "\n", + "plt.figure(figsize=(20, 12))\n", + "ax = plt.gca()\n", + "nx.draw_networkx_nodes(\n", + " G_sub, pos_sub,\n", + " node_color=node_colors_sub, \n", + " node_size=node_sizes_sub, \n", + " ax=ax\n", + ")\n", + "nx.draw_networkx_edges(\n", + " G_sub, \n", + " pos_sub, \n", + " edge_color=\"#cfcfcf\", \n", + " width=1.2, \n", + " ax=ax, \n", + " alpha=0.9\n", + ")\n", + "nx.draw_networkx_labels(\n", + " G_sub,\n", + " label_pos_sub,\n", + " labels=display_labels_sub, \n", + " font_size=8, \n", + " ax=ax,\n", + " bbox=dict(boxstyle=\"round,pad=0.2\", fc=\"white\", ec=\"none\", alpha=0.9)\n", + ")\n", + "ax.set_axis_off()\n", + "plt.tight_layout()\n", + "plt.show()" + ] + }, { "cell_type": "markdown", "metadata": {