Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
326 changes: 326 additions & 0 deletions tutorials/quickstart/graph_rag_with_milvus.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,173 @@
" relationid_2_passageids[relations.index(relation)].append(passage_id)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FUAm4-guPnzT"
},
"source": [
"**Optional**: run the following three cells to visualize graph knowledge for the nano_dataset"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TAAHENTxfJLO"
},
"source": [
 If you are using">
"> If you are using Google Colab, skip the following pip install: matplotlib and networkx are preinstalled there, and textwrap is part of the Python standard library."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "08kEDCmIecdg"
},
"outputs": [],
"source": [
"! pip install matplotlib networkx"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "C6Nm4UTVGquY"
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import networkx as nx\n",
"import textwrap"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 758
},
"id": "0DiR9RDdDmGQ",
"outputId": "3ebca80a-b0b5-4ee5-a645-ab330f85ae8a"
},
"outputs": [],
"source": [
"G_full = nx.Graph()\n",
"\n",
"entity_labels_full = {i: entities[i] for i in range(len(entities))}\n",
"relation_labels_full = {i: relations[i] for i in range(len(relations))}\n",
"\n",
"entity_nodes_full = []\n",
"relation_nodes_full = []\n",
"\n",
"for e in range(len(entities)):\n",
" n = f\"E:{entity_labels_full[e]}\"\n",
" G_full.add_node(n, kind=\"entity\", eid=e)\n",
" entity_nodes_full.append(n)\n",
"\n",
"for r in range(len(relations)):\n",
" n = f\"R:{relation_labels_full[r]}\"\n",
" G_full.add_node(n, kind=\"relation\", rid=r)\n",
" relation_nodes_full.append(n)\n",
"\n",
"for e, rel_ids in entityid_2_relationids.items():\n",
" e_name = f\"E:{entity_labels_full[e]}\"\n",
" for r in rel_ids:\n",
" r_name = f\"R:{relation_labels_full[r]}\"\n",
" G_full.add_edge(e_name, r_name)\n",
"\n",
"\n",
"def wrap_label(s, width=28):\n",
" return textwrap.fill(s, width=width)\n",
"\n",
"\n",
"display_labels_full = {}\n",
"for n in G_full.nodes:\n",
" if n.startswith(\"E:\"):\n",
" display_labels_full[n] = \"E: \" + wrap_label(n[2:], 28)\n",
" elif n.startswith(\"R:\"):\n",
" display_labels_full[n] = \"R: \" + wrap_label(n[2:], 28)\n",
" else:\n",
" display_labels_full[n] = wrap_label(n, 28)\n",
"\n",
"pos_full = {}\n",
"\n",
"left_x, right_x = -1.0, 1.0\n",
"available_height = 1.9\n",
"\n",
"\n",
"def side_positions(nodes):\n",
" if not nodes:\n",
" return {}\n",
" units = []\n",
" for n in nodes:\n",
" lines = display_labels_full[n].count(\"\\n\") + 1\n",
" units.append(lines + 0.6)\n",
" unit_h = available_height / max(sum(units), 1.0)\n",
" y = 0.95\n",
" y_map = {}\n",
" for n, u in zip(nodes, units):\n",
" y_map[n] = y\n",
" y -= u * unit_h\n",
" return y_map\n",
"\n",
"left_y_map = side_positions(entity_nodes_full)\n",
"right_y_map = side_positions(relation_nodes_full)\n",
"\n",
"for n in entity_nodes_full:\n",
" pos_full[n] = (left_x, left_y_map[n])\n",
"\n",
"for n in relation_nodes_full:\n",
" pos_full[n] = (right_x, right_y_map[n])\n",
"\n",
"label_pos_full = {\n",
" n: ((-1.02 if n.startswith(\"E:\") else 1.02), pos_full[n][1]) for n in G_full.nodes\n",
"}\n",
"\n",
"node_colors_full = []\n",
"node_sizes_full = []\n",
"for n, data in G_full.nodes(data=True):\n",
" if data.get(\"kind\") == \"entity\":\n",
" node_colors_full.append(\"#1f77b4\")\n",
" node_sizes_full.append(650)\n",
" else:\n",
" node_colors_full.append(\"#ffbb78\")\n",
" node_sizes_full.append(600)\n",
"\n",
"plt.figure(figsize=(20, 12))\n",
"ax = plt.gca()\n",
"nx.draw_networkx_nodes(\n",
" G_full, \n",
" pos_full, \n",
" node_color=node_colors_full, \n",
" node_size=node_sizes_full, \n",
" ax=ax\n",
")\n",
"nx.draw_networkx_edges(\n",
" G_full, \n",
" pos_full,\n",
" edge_color=\"#cfcfcf\",\n",
" width=1.2,\n",
" ax=ax, \n",
" alpha=0.9\n",
")\n",
"nx.draw_networkx_labels(\n",
" G_full, \n",
" label_pos_full, \n",
" labels=display_labels_full, \n",
" font_size=8, \n",
" ax=ax,\n",
" bbox=dict(boxstyle=\"round,pad=0.2\", fc=\"white\", ec=\"none\", alpha=0.9)\n",
")\n",
"ax.set_axis_off()\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
Expand Down Expand Up @@ -632,6 +799,165 @@
"We have get the candidate relationships by expanding the subgraph, which will be reranked by LLM in the next step."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "B5DtaEx-bWuu"
},
"source": [
"**Optional**: run the following three cells to visualize candidates subgraph"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9VD1ZF8Ifwuy"
},
"source": [
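 If you are using">
"> These cells need matplotlib and networkx. Both are preinstalled on Google Colab; elsewhere, install them first with pip (see the install cell in the earlier optional visualization step). textwrap is part of the Python standard library."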
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "PjnwcGIUblwb"
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import networkx as nx\n",
"import textwrap"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 758
},
"id": "qcucyXBTQMeA",
"outputId": "051f520d-e4c4-4124-f4a3-ef4347272333"
},
"outputs": [],
"source": [
"G_sub = nx.Graph()\n",
"\n",
"relation_candidate_id_set = set(relation_candidate_ids)\n",
"\n",
"entity_labels_sub = {i: entities[i] for i in range(len(entities))}\n",
"relation_labels_sub = {i: relations[i] for i in range(len(relations))}\n",
"\n",
"entity_nodes_sub = []\n",
"relation_nodes_sub = []\n",
"\n",
"for r in relation_candidate_ids:\n",
" n = f\"R:{relation_labels_sub[r]}\"\n",
" G_sub.add_node(n, kind=\"relation\", rid=r)\n",
" relation_nodes_sub.append(n)\n",
"\n",
"for e, rel_ids in entityid_2_relationids.items():\n",
"\n",
" sub_rels = [r for r in rel_ids if r in relation_candidate_id_set]\n",
" if not sub_rels:\n",
" continue\n",
" e_name = f\"E:{entity_labels_sub[e]}\"\n",
" if e_name not in G_sub:\n",
" G_sub.add_node(e_name, kind=\"entity\", eid=e)\n",
" entity_nodes_sub.append(e_name)\n",
" for r in sub_rels:\n",
" r_name = f\"R:{relation_labels_sub[r]}\"\n",
" G_sub.add_edge(e_name, r_name)\n",
"\n",
"\n",
"def wrap_label(s, width=28):\n",
" return textwrap.fill(s, width=width)\n",
"\n",
"\n",
"display_labels_sub = {}\n",
"for n in G_sub.nodes:\n",
" if n.startswith(\"E:\"):\n",
" display_labels_sub[n] = \"E: \" + wrap_label(n[2:], 28)\n",
" elif n.startswith(\"R:\"):\n",
" display_labels_sub[n] = \"R: \" + wrap_label(n[2:], 28)\n",
" else:\n",
" display_labels_sub[n] = wrap_label(n, 28)\n",
"\n",
"pos_sub = {}\n",
"\n",
"left_x, right_x = -1.0, 1.0\n",
"available_height = 1.9\n",
"\n",
"\n",
"def side_positions(nodes):\n",
" if not nodes:\n",
" return {}\n",
" units = []\n",
" for n in nodes:\n",
" lines = display_labels_sub[n].count(\"\\n\") + 1\n",
" units.append(lines + 0.6)\n",
" unit_h = available_height / max(sum(units), 1.0)\n",
" y = 0.95\n",
" y_map = {}\n",
" for n, u in zip(nodes, units):\n",
" y_map[n] = y\n",
" y -= u * unit_h\n",
" return y_map\n",
"\n",
"left_y_map = side_positions(entity_nodes_sub)\n",
"right_y_map = side_positions(relation_nodes_sub)\n",
"\n",
"for n in entity_nodes_sub:\n",
" pos_sub[n] = (left_x, left_y_map[n])\n",
"\n",
"for n in relation_nodes_sub:\n",
" pos_sub[n] = (right_x, right_y_map[n])\n",
"\n",
"label_pos_sub = {\n",
" n: ((-1.02 if n.startswith(\"E:\") else 1.02), pos_sub[n][1]) for n in G_sub.nodes\n",
"}\n",
"\n",
"node_colors_sub = []\n",
"node_sizes_sub = []\n",
"for n, data in G_sub.nodes(data=True):\n",
" if data.get(\"kind\") == \"entity\":\n",
" node_colors_sub.append(\"#1f77b4\")\n",
" node_sizes_sub.append(650)\n",
" else:\n",
" node_colors_sub.append(\"#ffbb78\")\n",
" node_sizes_sub.append(600)\n",
"\n",
"plt.figure(figsize=(20, 12))\n",
"ax = plt.gca()\n",
"nx.draw_networkx_nodes(\n",
" G_sub, pos_sub,\n",
" node_color=node_colors_sub, \n",
" node_size=node_sizes_sub, \n",
" ax=ax\n",
")\n",
"nx.draw_networkx_edges(\n",
" G_sub, \n",
" pos_sub, \n",
" edge_color=\"#cfcfcf\", \n",
" width=1.2, \n",
" ax=ax, \n",
" alpha=0.9\n",
")\n",
"nx.draw_networkx_labels(\n",
" G_sub,\n",
" label_pos_sub,\n",
" labels=display_labels_sub, \n",
" font_size=8, \n",
" ax=ax,\n",
" bbox=dict(boxstyle=\"round,pad=0.2\", fc=\"white\", ec=\"none\", alpha=0.9)\n",
")\n",
"ax.set_axis_off()\n",
"plt.tight_layout()\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"metadata": {
Expand Down