diff --git a/content/TinyGraphRAG/example/data.md b/content/TinyGraphRAG/example/data.md new file mode 100644 index 0000000..a90d9fd --- /dev/null +++ b/content/TinyGraphRAG/example/data.md @@ -0,0 +1,15 @@ +# Introduction + +## 1.1 Introduction + +Following a drizzling, we take a walk on the wet street. Feeling the gentle breeze and seeing the sunset glow, we bet the weather must be nice tomorrow. Walking to a fruit stand, we pick up a green watermelon with curly root and muffled sound; while hoping the watermelon is ripe, we also expect some good academic marks this semester after all the hard work on studies. We wish readers to share the same confidence in their studies, but to begin with, let us take an informal discussion on what is machine learning. + +Taking a closer look at the scenario described above, we notice that it involves many experience-based predictions. For example, why would we expect beautiful weather tomorrow after observing the gentle breeze and sunset glow? We expect this beautiful weather because, from our experience, the weather on the following day is often beautiful when we experience such a scene in the present day. Also, why do we pick the watermelon with green color, curly root, and muffled sound? It is because we have eaten and enjoyed many watermelons, and those satisfying the above criteria are usually ripe. Similarly, our learning experience tells us that hard work leads to good academic marks. We are confident in our predictions because we learned from experience and made experience-based decisions. + +Mitchell (1997) provides a more formal definition: ‘‘A computer program is said to learn from experience $E$ for some class of tasks $T$ and performance measure $P$, if its performance at tasks in $T$, as measured by $P$, improves with experience $E$.’’ + +E.g., Hand et al. (2001). + +While humans learn from experience, can computers do the same? The answer is ‘‘yes’’, and machine learning is what we need. Machine learning is the technique that improves system performance by learning from experience via computational methods. In computer systems, experience exists in the form of data, and the main task of machine learning is to develop learning algorithms that build models from data. By feeding the learning algorithm with experience data, we obtain a model that can make predictions (e.g., the watermelon is ripe) on new observations (e.g., an uncut watermelon). If we consider computer science as the subject of algorithms, then machine learning is the subject of learning algorithms. + +In this book, we use ‘‘model’’ as a general term for the outcome learned from data. In some other literature, the term ‘‘model’’ may refer to the global outcome (e.g., a decision tree), while the term ‘‘pattern’’ refers to the local outcome (e.g., a single rule). 
\ No newline at end of file diff --git a/content/TinyGraphRAG/help.ipynb b/content/TinyGraphRAG/help.ipynb new file mode 100644 index 0000000..f01e88c --- /dev/null +++ b/content/TinyGraphRAG/help.ipynb @@ -0,0 +1,175 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/home/calvin-lucas/Documents/DataWhale_Learning_Material/tiny-graphrag\n" + ] + } + ], + "source": [ + "# 注意:重新运行前需要:重启整个内核\n", + "import os\n", + "import sys\n", + "sys.path.append('.') # 添加当前目录到 Python 路径\n", + "print(os.getcwd()) # 验证下当前工作路径" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# 导入模块\n", + "from tinygraph.graph import TinyGraph\n", + "from tinygraph.embedding.zhipu import zhipuEmb\n", + "from tinygraph.llm.zhipu import zhipuLLM\n", + "\n", + "from neo4j import GraphDatabase\n", + "from dotenv import load_dotenv # 用于加载环境变量" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# 配置使用的 LLM 和 Embedding 服务,现在只支持 ZhipuAI\n", + "# 加载 .env文件, 从而导入api_key\n", + "load_dotenv() # 加载工作目录下的 .env 文件\n", + "\n", + "emb = zhipuEmb(\n", + " model_name=\"embedding-2\", # 嵌入模型\n", + " api_key=os.getenv('API_KEY')\n", + ")\n", + "llm = zhipuLLM(\n", + " model_name=\"glm-3-turbo\", # LLM 模型\n", + " api_key=os.getenv('API_KEY')\n", + ")\n", + "graph = TinyGraph(\n", + " url=\"neo4j://localhost:7687\",\n", + " username=\"neo4j\",\n", + " password=\"neo4j-passwordTGR\", # 初次登陆的默认密码为neo4j,此后需修改再使用\n", + " llm=llm,\n", + " emb=emb,\n", + ")\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Document 'example/data.md' has already been loaded, skipping import process.\n" + ] + } + ], + "source": [ + "# 使用 TinyGraph 添加文档。目前支持所有文本格式的文件。这一步的时间可能较长;\n", + "# 结束后,在当前目录下会生成一个 `workspace` 文件夹,包含 `community`、`chunk` 和 `doc` 信息\n", + "graph.add_document(\"example/data.md\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "数据库连接正常,节点数量: 29\n" + ] + } + ], + "source": [ + "# 再次验证数据库连接\n", + "with graph.driver.session() as session:\n", + " result = session.run(\"MATCH (n) RETURN count(n) as count\")\n", + " count = result.single()[\"count\"]\n", + " print(f\"数据库连接正常,节点数量: {count}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "本地查询结果:\n", + "The term \"dl\" is not explicitly defined in the provided context. However, based on the context's focus on machine learning, \"dl\" might commonly be interpreted as an abbreviation for \"deep learning,\" which is a subset of machine learning that involves neural networks with many layers (hence \"deep\"). 
Deep learning has become a prominent field, particularly in the realm of artificial intelligence, where it is used to recognize patterns and make predictions from large datasets.\n", + "\n", + "If \"dl\" refers to something else in the context of the user query, there would be no information to discern its meaning without further clarification or additional context.\n" + ] + } + ], + "source": [ + "# 执行局部查询测试\n", + "local_res = graph.local_query(\"what is dl?\")\n", + "print(\"\\n本地查询结果:\")\n", + "print(local_res)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "全局查询结果:\n", + "The term 'dl' is not explicitly mentioned in the provided data tables. Therefore, I don't know what 'dl' refers to in the context of the user's question. If 'dl' stands for 'Deep Learning,' it is a subset of machine learning that uses neural networks with many layers for feature extraction and modeling. However, this context is not provided in the data tables.\n" + ] + } + ], + "source": [ + "\n", + "# 执行全局查询测试\n", + "global_res = graph.global_query(\"what is dl?\")\n", + "print(\"\\n全局查询结果:\")\n", + "print(global_res)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "TinyGraphRAG_2025-04-08", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.16" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git "a/content/TinyGraphRAG/images/Learning-Algorithms\350\212\202\347\202\271\347\232\204\350\257\246\347\273\206\344\277\241\346\201\257.png" "b/content/TinyGraphRAG/images/Learning-Algorithms\350\212\202\347\202\271\347\232\204\350\257\246\347\273\206\344\277\241\346\201\257.png" new file mode 100644 index 0000000..dec5876 Binary files /dev/null and "b/content/TinyGraphRAG/images/Learning-Algorithms\350\212\202\347\202\271\347\232\204\350\257\246\347\273\206\344\277\241\346\201\257.png" differ diff --git "a/content/TinyGraphRAG/images/Tiny-Graphrag\346\265\201\347\250\213\345\233\276V2.png" "b/content/TinyGraphRAG/images/Tiny-Graphrag\346\265\201\347\250\213\345\233\276V2.png" new file mode 100644 index 0000000..366db93 Binary files /dev/null and "b/content/TinyGraphRAG/images/Tiny-Graphrag\346\265\201\347\250\213\345\233\276V2.png" differ diff --git "a/content/TinyGraphRAG/images/\345\233\276\346\225\260\346\215\256\345\272\223\347\244\272\344\276\213.png" "b/content/TinyGraphRAG/images/\345\233\276\346\225\260\346\215\256\345\272\223\347\244\272\344\276\213.png" new file mode 100644 index 0000000..ea17303 Binary files /dev/null and "b/content/TinyGraphRAG/images/\345\233\276\346\225\260\346\215\256\345\272\223\347\244\272\344\276\213.png" differ diff --git "a/content/TinyGraphRAG/images/\346\237\245\350\257\242\347\273\223\346\236\234\347\244\272\344\276\213.png" "b/content/TinyGraphRAG/images/\346\237\245\350\257\242\347\273\223\346\236\234\347\244\272\344\276\213.png" new file mode 100644 index 0000000..b210ad0 Binary files /dev/null and "b/content/TinyGraphRAG/images/\346\237\245\350\257\242\347\273\223\346\236\234\347\244\272\344\276\213.png" differ diff --git a/content/TinyGraphRAG/readme.md b/content/TinyGraphRAG/readme.md new file mode 100644 index 0000000..1de7a6e --- /dev/null +++ b/content/TinyGraphRAG/readme.md @@ -0,0 
+1,631 @@ +# Tiny-Graphrag使用指南与代码解读 +>此README包括两部分:1.引言;2.正文 +## 引言: +- Tiny-Graphrag是一个基于Graphrag的简化版本,包含了Graphrag的核心功能: 1.知识图谱构建;2.图检索优化;3.生成增强。创建Graphrag项目的目的是帮助大家理解Graphrag的原理并提供Demo来实现。 +- 本项目实现流程如下所示: + +
+ +
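+ 
+- 结合上面的流程图,项目的端到端调用过程大致如下方代码所示(与后文“使用方法”一节以及 `help.ipynb` 中的示例一致;运行前需要先按该节完成 Neo4j 与智谱 API 的配置):
+  ```python
+  # 端到端调用示意:建图(add_document)→ 图检索 → 生成增强(local/global query)
+  import os
+  from dotenv import load_dotenv
+  from tinygraph.graph import TinyGraph
+  from tinygraph.embedding.zhipu import zhipuEmb
+  from tinygraph.llm.zhipu import zhipuLLM
+
+  load_dotenv()  # 从 .env 中读取 API_KEY
+  emb = zhipuEmb(model_name="embedding-2", api_key=os.getenv("API_KEY"))
+  llm = zhipuLLM(model_name="glm-3-turbo", api_key=os.getenv("API_KEY"))
+  graph = TinyGraph(url="neo4j://localhost:7687", username="neo4j",
+                    password="neo4j-passwordTGR", llm=llm, emb=emb)
+
+  graph.add_document("example/data.md")     # 知识图谱构建(含社区与向量索引)
+  print(graph.local_query("what is dl?"))   # 局部查询
+  print(graph.global_query("what is dl?"))  # 全局查询
+  ```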
+ +- 用通俗语言来描述就是:**输入问题后,通过图结构运算层的计算,将得到的上下文交给一个“聪明的学生”(即大语言模型 LLM),让它基于这些上下文进行推理和回答问题。** +## 正文: +>正文包括三部分:1.Graphrag简要介绍;2.Tiny-Graphrag 使用方法;3.Tiny-Graphrag代码解读 +### Graphrag简要介绍 +--- +- 是什么? + - 基于知识图谱的检索增强生成技术,通过显式建模实体关系提升rag的多跳推理能力。 +- 提出时能够解决什么问题? + - 传统rag的局限:单跳检索(无法回答"特朗普和拜登的母校有何共同点?"类问题) 语义相似度≠逻辑相关性 + - Graphrag的改进:通过图路径实现多跳推理(如"特朗普→宾大→法学←拜登") +- 以微软Graphrag为例,其核心功能如下表所示: + +| 模块 | 模块描述 | +|:------|:-----| +| 知识图谱构建 | 核心功能之一,将文本或结构化数据转化为图结构(节点和边)。 | +| 图检索优化 | 基于图谱的拓扑关系(如多跳路径、子图匹配)改进传统向量检索。 | +| 生成增强 | 利用检索到的图结构(如子图、路径)增强大模型的生成逻辑性和准确性。 | + +- Leiden算法介绍 + - Leiden 算法是一种用于社区检测的高效算法,是 Louvain 算法的改进版本。它解决了 Louvain 算法可能产生的不连通社区问题,并提供了更高质量的社区划分。本文 1.3.7 生成社区内容 对此有进一步的介绍,更深入的了解请访问: https://arxiv.org/pdf/1810.08473 + +### Tiny-Graphrag 使用方法 +--- + - 本项目给出了Tiny-Graphrag使用方法,初学者可以先直接跑通这个程序,然后再继续了解具体原理。这样的学习曲线更缓和,能有效防止卡在代码理解层面而对代码的整体作用缺少理解,难以应用。下面给出Tiny-Graphrag使用的具体方法。 + - Tiny-Graphrag 使用方法 + - 个人主机环境:ubuntu24.04 + - 代码下载 + ```bash + git clone https://github.com/limafang/tiny-graphrag.git + cd tiny-graphrag + ``` + - 主机环境配置 + 1. 安装:`neo4j --version 5.26.5`,可使用wget配合pip来完成 + 2. 安装插件:`GDS`。 可从github上找到与`neo4j 5.26.5`**兼容**的`GDS 2.13.2`,将这个.jar文件放到neo4j的插件文件夹里。 + 3. 安装:`OpenJDK-21`。命令行`sudo apt install openjdk-21-jre-headless` + - 使用conda创建虚拟环境(虚拟环境创建此处仅作参考,学习者可以使用自己常用的开发环境来运行) + ```bash + conda create --name tinygrag python=3.10 -y # 虚拟环境创建 + conda activate tinygrag # 命令行激活虚拟环境 + conda install pip -y # 在conda环境内安装pip包管理工具 + ``` + - 环境中安装requirements.txt中的依赖,对应命令行为`pip install -r requirements.txt` + - 先运行Neo4j,命令行为:`sudo neo4j start`,然后在浏览器中登陆到neo4j,默认网址为:http://localhost:7474 + - 运行`help.ipynb` + - 注意每次全部重新运行都需要重启内核,否则在本地查询等步骤会报错 + - 使用本电脑首次运行完成耗时15分钟 + - 对于非首次运行的打开过程为: + 1. 激活当前项目的对应虚拟环境 + 2. 打开neo4j + 3. 运行`help.ipynb` + - 其他要求: + - 本项目以zhipuAI作为调用的大模型,需要调用其API,所以需要注册智谱API的帐号,从而获得API +### Tiny-Graphrag代码解读 +--- +>下面将按照Graphrag的三个核心功能来介绍本项目的代码: +#### 1. 
知识图谱构建 + +- 运行代码前需要启动neo4j客户端。 +- 模块导入,并添加API,其中API可以手动添加,也可以通过将API设置为环境变量的方法添加,本项目采用后者。 + ```python + # 导入模块 + import os + import sys + + from Tiny-Graph.graph import Tiny-Graph + from Tiny-Graph.embedding.zhipu import zhipuEmb + from Tiny-Graph.llm.zhipu import zhipuLLM + + from neo4j import GraphDatabase + from dotenv import load_dotenv # 用于加载环境变量 + + sys.path.append('.') # 添加当前目录到 Python 路径 + print(os.getcwd()) # 验证下当前工作路径 + + # 加载 .env文件, 从而导入api_key + load_dotenv() # 加载工作目录下的 .env 文件 + ``` +##### 1.1 emb、llm类的实例化 +- 将zhipuAi的嵌入模型(zhipuEmb)、zhipuLLM以及Tiny-Graph类分别实例化: +- llm以及模型的embedding服务,依次完成实例化。其中的llm以及embedding可以根据自己的需要再调整,此处作为示例用,两者分别传入了嵌入模型 / LLM模型的名称以及API_KEY +- 对应代码 + ```python + emb = zhipuEmb( + model_name="embedding-2", # 嵌入模型 + api_key=os.getenv('API_KEY') + ) + llm = zhipuLLM( + model_name="glm-3-turbo", # LLM + api_key=os.getenv('API_KEY') + ) + ``` +- 以`zhipuEmb`为例,分析下类的继承关系。此处的`zhipuEmb`类是继承于`BaseEmb`类,在类实例化的过程(此处为`emb = zhipuEmb`)中会先调用`__init__`方法; + ```python + class zhipuEmb(BaseEmb): + def __init__(self, model_name: str, api_key: str, **kwargs): + super().__init__(model_name=model_name, **kwargs) + self.client = ZhipuAI(api_key=api_key) # 创建 ZhipuAI 客户端,self.client 是zhipuEmb类的一个属性 + + def get_emb(self, text: str) -> List[float]: + emb = self.client.embeddings.create( + model=self.model_name, + input=text, + ) + return emb.data[0].embedding + ``` +- 为了调用`zhipuEmb`继承的`BaseEmb`类的属性,使用`super().__init__(model_name=model_name, **kwargs)`将模型名称传入`zhipuEmb`继承的`BaseEmb`类; +- 而`BaseEmb`类继承自`ABC`类(`Abstract Base Class`,抽象基类) +- `zhipuLLM`的实例化过程与此类似。 +##### 1.2 Tiny-Graph类的实例化 +- 传入了neo4j的默认网址、用户名、密码、llm、emb。 +- 对应代码 + ```python + graph = Tiny-Graph( + url="neo4j://localhost:7687", + username="neo4j", + password="neo4j-passwordTGR", + llm=llm, + emb=emb, + ) + ``` +- 实例化过程自动调用的`__init__`方法完成了创建Neo4j数据库驱动、设置语言模型、设置嵌入模型、设置工作目录等工作,详细注释见下方代码: + ```python + class Tiny-Graph: + """ + 一个用于处理图数据库和语言模型的类。 + + 该类通过连接到Neo4j图数据库,并使用语言模型(LLM)和嵌入模型(Embedding)来处理文档和图数据。 + 它还管理一个工作目录,用于存储文档、文档块和社区数据。 + """ + + def __init__( + self, + url: str, # Neo4j数据库的URL + username: str, # Neo4j数据库的用户名 + password: str, # Neo4j数据库的密码 + llm: BaseLLM, # 语言模型(LLM)实例 + emb: BaseLLM, # 嵌入模型(Embedding)实例 + working_dir: str = "workspace", # 工作目录,默认为"workspace" + ): + """ + 初始化Tiny-Graph类。 + + 参数: + - url: Neo4j数据库的URL + - username: Neo4j数据库的用户名 + - password: Neo4j数据库的密码 + - llm: 语言模型(LLM)实例 + - emb: 嵌入模型(Embedding)实例 + - working_dir: 工作目录,默认为"workspace" + """ + self.driver = driver = GraphDatabase.driver( + url, auth=(username, password) + ) # 创建Neo4j数据库驱动 + self.llm = llm # 设置语言模型 + self.embedding = emb # 设置嵌入模型 + self.working_dir = working_dir # 设置工作目录 + os.makedirs(self.working_dir, exist_ok=True) # 创建工作目录(如果不存在) + + # 定义文档、文档块和社区数据的文件路径 + self.doc_path = os.path.join(working_dir, "doc.txt") + self.chunk_path = os.path.join(working_dir, "chunk.json") + self.community_path = os.path.join(working_dir, "community.json") + + # 创建文件(如果不存在) + create_file_if_not_exists(self.doc_path) + create_file_if_not_exists(self.chunk_path) + create_file_if_not_exists(self.community_path) + + # 加载已加载的文档 + self.loaded_documents = self.get_loaded_documents() + ``` +##### 1.3 添加文档到图数据库 +- 使用Tiny-Graph类下的`add_document`方法来将指定路径的文档添加到图数据库中`graph.add_document("example/data.md")`。该方法会自动处理文档的分块和嵌入生成,并将结果存储在图数据库中。这里的路径是相对路径,指向当前工作目录下的example/data.md文件。其主要功能如下: +###### 1.3.1 检查文档是否已经分块; +- 对应代码 + ```python + # ================ Check if the document has been loaded ================ + if filepath in self.get_loaded_documents(): + print( + 
f"Document '{filepath}' has already been loaded, skipping import process." + ) + return # 在这段代码中,return 的作用是 终止函数的执行,并返回到调用该函数的地方 + ``` +- 功能:检查指定文档是否已经被加载过,避免重复处理。 +- 实现步骤: + 1. 调用`self.get_loaded_documents()`方法,读取已加载文档的缓存文件(doc.txt),返回一个包含已加载文档路径的集合 + 2. 检查文档路径是否已经存在,如果已经存在,则打印提示信息 + 3. 中止函数的执行,return在此段代码中的作用是中止函数的执行,并返回到调用该函数的地方。 + 2. 将文档分割成块(此处就是分割为json格式的文件); + + +###### 1.3.2. 将文档分割成块(此处就是分割为json格式的文件) +- 对应代码 + ```python + # ================ Chunking ================ + chunks = self.split_text(filepath) + existing_chunks = read_json_file(self.chunk_path) + + # Filter out chunks that are already in storage + new_chunks = {k: v for k, v in chunks.items() if k not in existing_chunks} + + if not new_chunks: + print("All chunks are already in the storage.") + return + + # Merge new chunks with existing chunks + all_chunks = {**existing_chunks, **new_chunks} + write_json_file(all_chunks, self.chunk_path) + print(f"Document '{filepath}' has been chunked.") + ``` +- 功能:将文档分割成多个小块(chunks),并将这些分块存储到一个JSON文件中,避免重复存储已经存在的分块。 +- 实现步骤: + 1. 分割文档:调用`chunks = self.split_text(filepath)`方法,将文档分割成多个小块,并且相邻小块之间有一定重叠,返回值chunks是一个字典,键是分块的唯一ID,值是分块的内容。 + 2. 读取已经存储的分块:`existing_chunks = read_json_file(self.chunk_path)`,调用该方法从chunk.json中读取已经存储的分块,返回值existing_chunks是一个字典,包含所有已经存储的分块。 + 3. 过渡新分块:`new_chunks = {k: v for k, v in chunks.items() if k not in existing_chunks}`,使用字典推导式过滤出新的分块,返回值new_chunks是一个字典,包含所有新的分块。 + 4. 检查是否有新的分块,如果new_chunks为空,也就是没有新的分块需要存储的话,打印提示信息并终止函数执行。 + ```python + if not new_chunks: + print("All chunks are already in the storage.") + return + ``` + 5. 合并分块:`all_chunks = {**existing_chunks, **new_chunks}`,使用字典包语法将existing_chunks和new_chunks合并为一个新的字典。 + 6. 写入JSON文件:`write_json_file(all_chunks, self.chunk_path)`,将合并后的分块写入chunk.json文件。 + 7. 打印提示信息。 + +###### 1.3.3 从块中提取实体(entities)和三元组(triplets); +- 对应代码 + ```python + # ================ Entity Extraction ================ + all_entities = []# 用于存储从文档块中提取的实体 + all_triplets = []# 用于存储从所有文档中提取的三元组 + + # 遍历文档块,每个分块有一个唯一的chunk_id和对应的内容chunk_content + for chunk_id, chunk_content in tqdm( + new_chunks.items(), desc=f"Processing '{filepath}'" + ): + try: + # 从当前分块中提取实体,每个实体包含名称、描述、关联的分块ID以及唯一的实体ID + entities = self.get_entity(chunk_content, chunk_id=chunk_id) + all_entities.extend(entities) + # 从当前分块中提取三元组,每个三元组由主语(subject)、谓语(predicate)和宾语(object)组成,表示实体之间的关系 + triplets = self.get_triplets(chunk_content, entities) + all_triplets.extend(triplets) + except: + print( + f"An error occurred while processing chunk '{chunk_id}'. SKIPPING..." + ) + + print( + f"{len(all_entities)} entities and {len(all_triplets)} triplets have been extracted." + ) + ``` +- 功能:遍历文档块以及从当前分块中提取实体和三元组,其中提取实体和三元组,均使用llm来完成。下面分析这代代码的实现步骤,再简单解释下实体和三元组的定义与结构。 +- 实现步骤: + 1. 初始化存储容器; + 2. 遍历文档块,遍历new_chunks字典,其中每一块有一个chunk_id和对应的内容chunk_content。 + 3. 提取实体和三元组:首先调用self.get_entity(chunk_conten, chunk_id= chunk_id)方法,从当前分块中提取实体,将提取到的实体追加到all_entities列表中;然后调用self.get_triplets(chunk_content, entities)方法,从当前分块中提取三元组,将提取到的三元组追加到all_triplets列表中。如果在处理过程中出现错误,打印错误信息并跳过该分块。 + 4. 
打印提取的实体和三元组总数,便于检查提取结果。 +- 实体的定义与结构 + - 定义:实体是文档中提取的关键概念或对象,通常是名词或专有名词。 + - 结构示意 + ```python + { + "name": "Entity Name", # 实体名称 + "description": "Entity Description", # 实体描述 + "chunks id": ["chunk-1a2b3c"], # 关联的文档块 ID + "entity id": "entity-123456" # 实体的唯一标识符 + } + ``` +- 三元组的定义与结构 + - 定义:三元组是描述实体之间关系的结构,包含主语(subject)、谓语(predicate)和宾语(object)。 + - 结构示意 + ```python + { + "subject": "Subject Name", # 主语名称 + "subject_id": "entity-123456", # 主语的唯一标识符 + "predicate": "Predicate Name", # 谓语(关系名称) + "object": "Object Name", # 宾语名称 + "object_id": "entity-654321" # 宾语的唯一标识符 + } + ``` + +- 实体(Entities)是图数据库中的节点,表示文档中的关键概念。三元组(Triplets)是+图数据库中的边,表示实体之间的关系,Neo4j中的节点与三元组关系如下所示: + +
+ +
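+ 
+- 下面给出一个最小示意(非项目自带代码),演示如何用与本项目相同的 neo4j Python 驱动查看 `create_triplet` 写入的实体节点与关系;连接参数沿用前文 1.2 节的示例,`LIMIT 25` 仅作演示:
+  ```python
+  # 示意:查看 (Entity)-[Relationship]->(Entity) 形式的三元组
+  from neo4j import GraphDatabase
+
+  driver = GraphDatabase.driver("neo4j://localhost:7687", auth=("neo4j", "neo4j-passwordTGR"))
+  with driver.session() as session:
+      result = session.run(
+          "MATCH (a:Entity)-[r:Relationship]->(b:Entity) "
+          "RETURN a.name AS subject, r.name AS predicate, b.name AS object LIMIT 25"
+      )
+      for record in result:
+          print(record["subject"], "-[", record["predicate"], "]->", record["object"])
+  driver.close()
+  ```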
+ +###### 1.3.4 执行实体消歧和三元组更新。实体消歧有两种方法可以选择,默认将同名实体认为是同一实体 +- 对应代码 + ```python + # ================ Entity Disambiguation ================ + entity_names = list(set(entity["name"] for entity in all_entities)) + + if use_llm_deambiguation: + entity_id_mapping = {} + for name in entity_names: + same_name_entities = [ + entity for entity in all_entities if entity["name"] == name + ] + transform_text = self.llm.predict( + ENTITY_DISAMBIGUATION.format(same_name_entities) + ) + entity_id_mapping.update( + get_text_inside_tag(transform_text, "transform") + ) + else: + entity_id_mapping = {} + for entity in all_entities: + entity_name = entity["name"] + if entity_name not in entity_id_mapping: + entity_id_mapping[entity_name] = entity["entity id"] + + for entity in all_entities: + entity["entity id"] = entity_id_mapping.get( + entity["name"], entity["entity id"] + ) + + triplets_to_remove = [ + triplet + for triplet in all_triplets + if entity_id_mapping.get(triplet["subject"], triplet["subject_id"]) is None + or entity_id_mapping.get(triplet["object"], triplet["object_id"]) is None + ] + + updated_triplets = [ + { + **triplet, + "subject_id": entity_id_mapping.get( + triplet["subject"], triplet["subject_id"] + ), + "object_id": entity_id_mapping.get( + triplet["object"], triplet["object_id"] + ), + } + for triplet in all_triplets + if triplet not in triplets_to_remove + ] + all_triplets = updated_triplets + ``` +- 对于实体消歧(Entity Disambiguation)部分 + - 功能: + - 解决同名实体歧义的问题,确保每个实体都有唯一的entity_id。如果启用了LLM消歧(use_llm_deambiguation=True),则默认将同名实体视为同一实体;如果未启用LLM消歧,则默认将同名实体视为同一实体。本项目采用后者。 + - 实现步骤: + 1. 提取实体的名称存储到entity_names中; + 2. 使用默认方法消歧义 + 3. 更新实体ID +- 对于三元组更新(Triplet Update)部分 + - 功能: + - 根据消歧后的实体ID更新三元组,并移除无效的三元组。 + - 实现步骤: + 1. 移除所有无效的三元组(如果三元组的主语或者宾语的实体ID无法在entity_id_mapping中找到,则将其标记为无效); + 2. 更新三元组(对于有效的三元组,更新其主语和宾语的实体ID) + 3. 保存更新后的三元组(将更新后的三元组列表保存到all_triplets中) +###### 1.3.5 合并实体和三元组 +- 对应代码 + ```python + # ================ Merge Entities ================ + entity_map = {} + + for entity in all_entities: + entity_id = entity["entity id"] + if entity_id not in entity_map: + entity_map[entity_id] = { + "name": entity["name"], + "description": entity["description"], + "chunks id": [], + "entity id": entity_id, + } + else: + entity_map[entity_id]["description"] += " " + entity["description"] + + entity_map[entity_id]["chunks id"].extend(entity["chunks id"]) + ``` +- 功能: + - 将所有提取的实体(all_entities)按照其唯一标识符(entity_id)进行归并,确保同一个实体的描述和关联的文档块ID被整合到一起 +- 实现步骤: + - 使用一个字典entity_map,以entity_id作为键,存储每个实体的合并信息。如果某个实体entity_id已经存在于entity_map中,则将其描述和文档块ID合并到已有的实体中。 +###### 1.3.6 将合并的实体和三元组存储到Neo4j的图数据库中 +- 对应代码 + ```python + # ================ Store Data in Neo4j ================ + for triplet in all_triplets: + subject_id = triplet["subject_id"] + object_id = triplet["object_id"] + + subject = entity_map.get(subject_id) + object = entity_map.get(object_id) + if subject and object: + self.create_triplet(subject, triplet["predicate"], object) + ``` +- 功能:将提取的三元组(triplets)存储到Neo4j图数据库中 +- 实现步骤: + 1. 遍历all_triplets列表,逐个处理每个三元组 + 2. 根据三元组中的subject_id和object_id,从entity_map中获取对应的实体信息 + 3. 
如果主语和宾语实体都存在,则调用self.create_triplet方法,将三元组存储到Neo4j中。其中的create_triplet方法能够通过Cypher查询语句将实体和关系插入到数据库中。 +###### 1.3.7 生成社区内容 + +基于上一步构建的知识图谱索引,可以使用多种社区检测算法对图进行划分,以识别其中强连接的节点集合(即社区)。在我们的处理流程中,我们采用 Leiden 社区检测算法以递归方式构建社区层级结构:首先在全图中识别出初始社区,然后在每个社区内继续执行子社区检测,直到无法进一步划分为止,形成叶级社区。 + +Leiden 算法主要包括以下三个阶段: + + +- 节点聚合: 在固定社区划分的前提下,尝试将每个节点移动到邻居节点所属社区,以提升总体模块度。 +- 社区细化: 对每个社区进行局部划分,确保每个社区子图中的所有节点之间是连通的,防止出现不连通的社区。 +- 图聚合: 构建新的超图,将每个社区作为一个超级节点,重复第一步,形成递归的社区层级结构。 + + +模块度用于衡量当前社区划分相较于随机划分的“好坏”,定义如下: + +$$ +Q = \frac{1}{2m} \sum_{i,j} \left[ A_{ij} - \frac{k_i k_j}{2m} \right] \delta(c_i, c_j) +$$ + +其中: + +$A_{ij}$:节点 $i$ 与节点 $j$ 之间的边的权重; +$k_i$:节点 $i$ 的度(边的总权重); +$m$:图中所有边的总权重的一半(即 $m = \frac{1}{2} \sum_{i,j} A_{ij}$); +$c_i$:节点 $i$ 所属的社区编号; +$\delta(c_i, c_j)$:当 $i$ 与 $j$ 属于同一社区时为 1,否则为 0。 + +在社区划分完成后,我们为社区层级结构中的每一个社区生成类报告形式的摘要。这一过程支持对大规模数据集的结构性理解,提供了不依赖具体查询的语义概览,帮助用户快速掌握语料库中各主题的全貌。 + +例如,用户可以浏览某一层级的社区摘要以确定感兴趣的主题,并进一步阅读下级社区报告,获取更细粒度的信息。尽管这些摘要本身具有独立意义,但我们主要关注其作为图索引机制的一部分,在响应全局性查询时的效用。 + +摘要的生成采用模板方法,逐步将图中节点、边及其声明的摘要填充到社区摘要模板中。较低层级的社区摘要将作为上层社区摘要生成的基础。 + + 对于叶级社区,从图中提取的节点和边的摘要被按优先级排序加入到 LLM 上下文窗口中。排序标准是:依据边的源节点和目标节点的整体度量(即显著性)降序排列。依次添加源节点描述、目标节点描述、边的描述。 + +对于高级社区,若所有元素摘要在上下文窗口的 token 限制内可容纳,则按叶级社区的方法汇总所有元素摘要;否则,将子社区按摘要的 token 数量降序排序,逐步用更短的子社区摘要替换较长的元素摘要,直到整体摘要符合上下文窗口限制。 + + + +- 对应代码 + ```python + # ================ communities ================ + self.gen_community() + self.generate_community_report() + ``` +- 功能: + - 生成社区:通过图算法(本项目为 Leiden 算法)检测图中的社区结构。 + - 生成社区报告:借助大语言模型为每个社区生成详细的报告,描述社区中的实体和关系。 +- 实现步骤: + 1. 对于生成社区功能,调用 self.gen_community() 方法: + - 使用 Neo4j 的图算法(如 gds.leiden.write)检测社区。 + - 生成社区架构(community schema),包括社区的层级、节点、边等信息。 + - 将社区架构存储到 community.json 文件中。 + 2. 对于生成社区报告功能,调用 self.generate_community_report() 方法: + - 遍历每个社区,生成包含实体和关系的报告。 + - 报告通过大语言模型(LLM)生成,描述社区的结构和内容。 + +###### 1.3.8 生成嵌入式向量 +- 对应代码 + ```python + # ================ embedding ================ + self.add_embedding_for_graph() + self.add_loaded_documents(filepath) + print(f"doc '{filepath}' has been loaded.") + ``` +- 功能: + - 为图数据库中的每个实体节点生成嵌入向量(embedding),用于计算相似度(本项目采用余弦相似度)和查询。 + - 将处理过的文档路径记录到缓存文件中,避免重复处理。 +- 实现步骤: + 1. 生成嵌入:调用 self.add_embedding_for_graph() 方法:遍历图数据库中的每个实体节点;使用嵌入模型(self.embedding)计算节点描述的嵌入向量;将嵌入向量存储到节点的 embedding 属性中。 + 2. 记录文档路径:调用 self.add_loaded_documents(filepath) 方法:将当前文路径添加到缓存文件中,避免重复加载。 +- 最终生成的图数据信息如下所示: + +
+ +
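+ 
+- 作为 1.3.8 的补充,下面给出余弦相似度计算的一个最小示意(假设其与 `tinygraph.utils` 中项目实际使用的 `cosine_similarity` 实现等价,仅用于说明“查询向量与实体 embedding 比较”的检索原理):
+  ```python
+  import numpy as np
+
+  def cosine_similarity(a, b) -> float:
+      """余弦相似度:两向量夹角的余弦,越接近 1 表示语义越相近"""
+      a = np.asarray(a, dtype=float)
+      b = np.asarray(b, dtype=float)
+      return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
+
+  # get_topk_similar_entities 即按该相似度对实体节点降序排序并取前 k 个
+  print(cosine_similarity([1.0, 0.0], [0.7, 0.7]))  # ≈ 0.707
+  ```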
+ +###### 1.3.9 验证下数据库连接是否正常(当然,此步也可省略) +- 对应代码 + ```python + with graph.driver.session() as session: + result = session.run("MATCH (n) RETURN count(n) as count") + count = result.single()["count"] + print(f"数据库连接正常,节点数量: {count}") + ``` +#### 2. 图检索优化 +##### 2.1 两种图检索方法 +- 按照Tiny-Graphrag demo代码的执行过程,图检索优化过程有两种:分别为Tiny-Graph类中`local_query`方法和`global_query`方法。 + - 全局查询和局部查询的特点如下表所示: + +| 查询类型 | 特点 | 适用场景 | +|----------|------|----------| +| 全局查询(global_query) | • 基于社区层级评分
• 筛选候选社区
• 返回排序列表 | • 高层次理解
• 全局视角分析 | +| 局部查询(local_query) | • 基于直接关联上下文
• 提取精确实体/关系
• 返回多部分结果 | • 精确定位
• 深度分析 | +- 下面依次分析下`local_query`方法和`global_query`方法的具体实现过程。 +##### 2.2 local_query方法 + +局部查询方法主要用于回答那些聚焦于单一或少数几个实体的问题,比如“孙悟空的生平”或“矢车菊的治疗特性”。这种方法通过一系列步骤,从知识图谱和原始语料中提取与查询密切相关的信息,以构建精准的上下文,并最终生成高质量的回答。 + +首先,系统会将用户的查询转换为一个向量表示,用于捕捉其语义含义。接着,它会在知识图谱中为每个实体节点也生成相应的向量表示,然后通过计算它们与查询向量之间的相似度,筛选出那些与查询密切相关的实体。如果它们之间的相似度超过一个设定的阈值,就会被认为是“相关实体”。 + +找到这些关键实体之后,系统会进一步扩展它们的上下文信息,包括提取它们在知识图谱中直接连接的邻居节点和边,以补充相关的结构性信息。同时,系统还会在原始文本语料中查找与这些实体强相关的内容片段,为后续生成提供更多语义背景。 + +最后,这些相关实体、邻居节点以及对应的文本片段会被组合成一个紧凑的局部上下文窗口,这个窗口会被输入到 LLM 中,用来生成针对用户问题的具体回答。 + +- 在Tiny_Graphrag_test.ipynb中,执行局部查询测试时,使用的是local_query方法 + - 具体代码为:`local_res = graph.local_query("what is dl?")` + - 其中调用的方法`local_query("what is dl?")`,将"what is dl?"传递给`local_query()`方法,以下是`local_query()`方法的代码内容和代码解读 +- 代码内容 + ```python + def local_query(self, query): + context = self.build_local_query_context(query) # 分别包括社区、实体、关系、文档块这四部分 + prompt = LOCAL_QUERY.format(query=query, context=context) # 需要的参数context以及query都在该方法内得到了 + response = self.llm.predict(prompt) + return response + ``` +- 代码解读 + - 执行`context = self.build_local_query_context(query)`后,根据用户问题(本项目中是"what is dl?")得到的包括社区、实体、关系、文档块的这四部分内容的上下文(context)。得到上下文的具体方法为:`build_local_query_context(self, query)`,该方法内的代码执行顺序是: + 1. 获得输入文本的嵌入向量。对应代码为:`query_emb = self.embedding.get_emb(query)` + 2. 获得前k个最相似的实体,相似判断的依据是余弦相似度。对应代码为:`topk_similar_entities_context = self.get_topk_similar_entities(query_emb)` + 3. 获得前k个最相似的社区,依据的方法是: + - 利用上面得到的最相似实体; + - 只要包含上述的任意topk节点,就认为是相似社区(社区:community,由相互关联的节点组成的集合)。对应代码为: + ```python + topk_similar_communities_context = self.get_communities( + topk_similar_entities_context + ) + ``` + 4. 获得前k个最相似的关系,依据的方法是:在`get_relations`方法中调用`get_node_edgs`方法,获取该实体的所有关系边,认为这些边就是similar relation。对应代码为: + ```python + topk_similar_relations_context = self.get_relations( + topk_similar_entities_context, query + ) + ``` + 5. 获得前k个最相似的文档块,依据是方法是:在`get_chunks`方法中调用`get_node_chunks`方法,获取该实体关联的文档块,认为这些文档块就是similar chunks。对应代码为: + ```python + topk_similar_chunks_context = self.get_chunks( + topk_similar_entities_context, query + ) + ``` + 6. 
`build_local_query_context()`方法最终返回的是一个多行字符串,包括: + - Reports:社区报告; + - Entities:与查询最相似的实体; + - Relationships:这些实体之间的关系; + - Sources:这些实体关联的文档块。对应的代码为: + ```python + return f""" + -----Reports----- + ```csv + {topk_similar_communities_context} + ``` + -----Entities----- + ```csv + {topk_similar_entities_context} + ``` + -----Relationships----- + ```csv + {topk_similar_relations_context} + ``` + -----Sources----- + ```csv + {topk_similar_chunks_context} + ``` + """ + ``` + - 之后的`prompt = LOCAL_QUERY.format(query=query, context=context)`可以理解为根据刚刚生成的context作为上下文,生成prompt为大模型使用。 + - 最后 ,`response = self.llm.predict(prompt)`是将上文得到的prompt传输给大模型,从而让大模型做推理和回答,然后该方法返回到`response(return response)`作为大模型的回答结果。 +###### 2.3 global_query方法 + +全局查询方法适用于更复杂的问题,尤其是那些需要跨越多个知识图谱社区、结构性较强的查询,比如“曹操与大乔之间的联系”。这种类型的问题通常难以通过关注单一实体来解决,因此需要更宏观的视角和层级化的信息整合。 + +整个流程围绕社区结构展开。知识图谱被组织成多层的社区集合,每一层的社区代表一组语义相关的实体节点,每个社区都有一个由 LLM 生成的摘要,简要概括了该社区的主要信息。 + +处理这类查询时,系统首先将用户提出的问题转换为一个向量表示,用于捕捉其深层语义。随后,它会将这个向量与所有社区摘要的嵌入进行比较,筛选出与查询最相关的一组社区。这一筛选基于相似度阈值,确保只保留与查询密切相关的区域。 + +接下来,系统会把这些相关社区的摘要进一步切分成较小的文本块,每一块单独输入到语言模型中进行处理。模型会为每个文本块生成一个中间响应,识别出若干关键信息点,并为每个信息点打分,以反映其对回答问题的贡献度。 + +然后,系统会根据评分,从所有中间响应中挑选出最重要的若干信息点,组成一个高质量的全局上下文。这些信息点跨越不同的社区,构成了一个面向复杂查询的知识核心。 + +最后,这个上下文连同原始问题一起被输入到语言模型中,生成最终的答案。通过这种方式,全局查询不仅能覆盖广泛的实体与关系,还能整合跨社区的背景信息,提供更深入、综合的回答。 + + +- 在`Tiny_Graphrag_test.ipynb`中,执行全局查询测试时,使用的是`global_query`方法 + - 具体代码为:`global_res = graph.global_query("what is dl?")` + - 其中调用的方法`global_query("what is dl?")`,将"what is dl?"传递给`global_query()`方法,以下是`global_query()`方法的代码内容和代码解读 +- 代码内容: + ```python + def global_query(self, query, level=1): + context = self.build_global_query_context(query, level) # 得到的是一个列表,包含社区的描述和分数 + prompt = GLOBAL_QUERY.format(query=query, context=context)# 将得到的context传入到prompt中 + response = self.llm.predict(prompt)# 将prompt传入到llm中,得到最终的结果,也就是将包含描述和分数的列表传入到llm中 + return response + ``` +- 代码解读: +- 运行`context = self.build_global_query_context(query, level)`时,会根据用户问题(本项目中是“what is dl?")得到的包含社区描述和分数的上下文(context)。对应代码为:`context = self.build_global_query_context(query, level)`,该方法内的代码执行顺序是: + 1. 设定空的候选社区(字典)以及空的社区评分列表(列表),并筛选符合层级要求的社区。对应代码为: + ```python + communities_schema = self.read_community_schema() + candidate_community = {} # 候选社区 字典 + points = [] # 社区评分列表 列表 + # 筛选符合层级要求的社区 + for communityid, community_info in communities_schema.items(): + if community_info["level"] < level: + candidate_community.update({communityid: community_info}) + ``` + 2. 计算候选的社区的评分,通过调用`map_community_points`函数,结合社区报告和大语言模型的能力,为每个候选社区生成与查询内容(如 "What is DL?")相关程度的评分。对应的代码为: + ```python + for communityid, community_info in candidate_community.items(): + points.extend(self.map_community_points(community_info["report"], query)) + ``` + 3. 按照评分降序排序,得到包含描述和分数的列表。描述是社区的描述,分数是查询的相关性得分。对应代码为: + ```python + points = sorted(points, key=lambda x: x[-1], reverse=True) + return points # 得到包含描述和分数的列表,描述是社区的描述,分数是查询的相关性得分 + ``` + 4. 之后的`prompt = GLOBAL_QUERY.format(query=query, context=context)`可以理解为根据刚刚生成的context作为上下文,生成prompt给大模型使用。 + 5. 最后,`response=self.llm.predict(prompt)`将上文得到的prompt传输给大模型。`return response`作为大模型的回答结果。 +##### 2.4 生成增强 +1. 通俗来讲就是:将得到的上下文输入给大模型,基于此上下文,大模型作推理和回答 +2. 在本项目代码中,`local_query`方法和`global_query`方法的将各自得到的上下文传输给大模型将是生成增强的过程。 + - 局部查询和全局查询成功运行的示例: + +
+ +
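+ 
+- 补充:如果想把智谱替换成其他大模型服务,只需仿照 `zhipuLLM` 继承 `BaseLLM` 并实现 `predict` 方法(`tinygraph/llm/groq.py` 就是按同样方式实现的另一个后端)。下面是一个示意性的自定义后端骨架,其中 `myLLM` 及其返回的占位文本均为假设,实际使用时应替换为所选服务 SDK 的真实调用:
+  ```python
+  from typing import Any, Optional
+  from tinygraph.llm.base import BaseLLM
+
+  class myLLM(BaseLLM):
+      """示意:自定义 LLM 后端,接口与 zhipuLLM / groqLLM 保持一致"""
+
+      def __init__(self, model_name: str, api_key: str,
+                   model_params: Optional[dict[str, Any]] = None, **kwargs: Any):
+          super().__init__(model_name, model_params, **kwargs)
+          self.api_key = api_key  # 假设:此处仅保存凭证,实际应初始化所选服务的客户端
+
+      def predict(self, input: str) -> str:
+          # 假设:这里应调用所选服务的对话接口;为保证示例可独立运行,先返回占位文本
+          return f"[{self.model_name}] echo: {input}"
+  ```
+- 嵌入模型同理:继承 `BaseEmb` 并实现 `get_emb` 方法,即可接入其他向量化服务。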
+ diff --git a/content/TinyGraphRAG/tinygraph/embedding/__init__.py b/content/TinyGraphRAG/tinygraph/embedding/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/content/TinyGraphRAG/tinygraph/embedding/base.py b/content/TinyGraphRAG/tinygraph/embedding/base.py new file mode 100644 index 0000000..cfe15ee --- /dev/null +++ b/content/TinyGraphRAG/tinygraph/embedding/base.py @@ -0,0 +1,25 @@ +from abc import ABC, abstractmethod +from typing import List, Any, Optional + + +class BaseEmb(ABC): + def __init__( + self, + model_name: str, + model_params: Optional[dict[str, Any]] = None, + **kwargs: Any, + ): + self.model_name = model_name + self.model_params = model_params or {} + + @abstractmethod + def get_emb(self, input: str) -> List[float]: + """Sends a text input to the embedding model and retrieves the embedding. + + Args: + input (str): Text sent to the embedding model + + Returns: + List[float]: The embedding vector from the model. + """ + pass diff --git a/content/TinyGraphRAG/tinygraph/embedding/zhipu.py b/content/TinyGraphRAG/tinygraph/embedding/zhipu.py new file mode 100644 index 0000000..cef5dbb --- /dev/null +++ b/content/TinyGraphRAG/tinygraph/embedding/zhipu.py @@ -0,0 +1,16 @@ +from zhipuai import ZhipuAI +from typing import List +from .base import BaseEmb + + +class zhipuEmb(BaseEmb): + def __init__(self, model_name: str, api_key: str, **kwargs): + super().__init__(model_name=model_name, **kwargs) + self.client = ZhipuAI(api_key=api_key) + + def get_emb(self, text: str) -> List[float]: + emb = self.client.embeddings.create( + model=self.model_name, + input=text, + ) + return emb.data[0].embedding diff --git a/content/TinyGraphRAG/tinygraph/graph.py b/content/TinyGraphRAG/tinygraph/graph.py new file mode 100644 index 0000000..5313a32 --- /dev/null +++ b/content/TinyGraphRAG/tinygraph/graph.py @@ -0,0 +1,714 @@ +from neo4j import GraphDatabase +import os +from tqdm import tqdm +from .utils import ( + get_text_inside_tag, + cosine_similarity, + compute_mdhash_id, + read_json_file, + write_json_file, + create_file_if_not_exists, +) +from .llm.base import BaseLLM +from .embedding.base import BaseEmb +from .prompt import * +from typing import Dict, List, Optional, Tuple, Union +import numpy as np +from collections import defaultdict +import json + +from dataclasses import dataclass + + +@dataclass +class Node: + name: str + desc: str + chunks_id: list + entity_id: str + similarity: float + + +class TinyGraph: + """ + 一个用于处理图数据库和语言模型的类。 + + 该类通过连接到Neo4j图数据库,并使用语言模型(LLM)和嵌入模型(Embedding)来处理文档和图数据。 + 它还管理一个工作目录,用于存储文档、文档块和社区数据。 + """ + + def __init__( + self, + url: str, # Neo4j数据库的URL + username: str, # Neo4j数据库的用户名 + password: str, # Neo4j数据库的密码 + llm: BaseLLM, # 语言模型(LLM)实例 + emb: BaseLLM, # 嵌入模型(Embedding)实例 + working_dir: str = "workspace", # 工作目录,默认为"workspace" + ): + """ + 初始化TinyGraph类。 + + 参数: + - url: Neo4j数据库的URL + - username: Neo4j数据库的用户名 + - password: Neo4j数据库的密码 + - llm: 语言模型(LLM)实例 + - emb: 嵌入模型(Embedding)实例 + - working_dir: 工作目录,默认为"workspace" + """ + self.driver = driver = GraphDatabase.driver( + url, auth=(username, password) + ) # 创建Neo4j数据库驱动 + self.llm = llm # 设置语言模型 + self.embedding = emb # 设置嵌入模型 + self.working_dir = working_dir # 设置工作目录 + os.makedirs(self.working_dir, exist_ok=True) # 创建工作目录(如果不存在) + + # 定义文档、文档块和社区数据的文件路径 + self.doc_path = os.path.join(working_dir, "doc.txt") + self.chunk_path = os.path.join(working_dir, "chunk.json") + self.community_path = os.path.join(working_dir, "community.json") + + # 创建文件(如果不存在) + 
create_file_if_not_exists(self.doc_path) + create_file_if_not_exists(self.chunk_path) + create_file_if_not_exists(self.community_path) + + # 加载已加载的文档 + self.loaded_documents = self.get_loaded_documents() + + def create_triplet(self, subject: dict, predicate, object: dict) -> None: + """ + 创建一个三元组(Triplet)并将其存储到Neo4j数据库中。 + + 参数: + - subject: 主题实体的字典,包含名称、描述、块ID和实体ID + - predicate: 关系名称 + - object: 对象实体的字典,包含名称、描述、块ID和实体ID + + 返回: + - 查询结果 + """ + # 定义Cypher查询语句,用于创建或合并实体节点和关系 + query = ( + "MERGE (a:Entity {name: $subject_name, description: $subject_desc, chunks_id: $subject_chunks_id, entity_id: $subject_entity_id}) " + "MERGE (b:Entity {name: $object_name, description: $object_desc, chunks_id: $object_chunks_id, entity_id: $object_entity_id}) " + "MERGE (a)-[r:Relationship {name: $predicate}]->(b) " + "RETURN a, b, r" + ) + + # 使用数据库会话执行查询 + with self.driver.session() as session: + result = session.run( + query, + subject_name=subject["name"], + subject_desc=subject["description"], + subject_chunks_id=subject["chunks id"], + subject_entity_id=subject["entity id"], + object_name=object["name"], + object_desc=object["description"], + object_chunks_id=object["chunks id"], + object_entity_id=object["entity id"], + predicate=predicate, + ) + + return + + def split_text(self,file_path:str, segment_length=300, overlap_length=50) -> Dict: + """ + 将文本文件分割成多个片段,每个片段的长度为segment_length,相邻片段之间有overlap_length的重叠。 + + 参数: + - file_path: 文本文件的路径 + - segment_length: 每个片段的长度,默认为300 + - overlap_length: 相邻片段之间的重叠长度,默认为50 + + 返回: + - 包含片段ID和片段内容的字典 + """ + chunks = {} # 用于存储片段的字典 + with open(file_path, "r", encoding="utf-8") as file: + content = file.read() # 读取文件内容 + + text_segments = [] # 用于存储分割后的文本片段 + start_index = 0 # 初始化起始索引 + + # 循环分割文本,直到剩余文本长度不足以形成新的片段 + while start_index + segment_length <= len(content): + text_segments.append(content[start_index : start_index + segment_length]) + start_index += segment_length - overlap_length # 更新起始索引,考虑重叠长度 + + # 处理剩余的文本,如果剩余文本长度小于segment_length但大于0 + if start_index < len(content): + text_segments.append(content[start_index:]) + + # 为每个片段生成唯一的ID,并将其存储在字典中 + for segement in text_segments: + chunks.update({compute_mdhash_id(segement, prefix="chunk-"): segement}) + + return chunks + + def get_entity(self, text: str, chunk_id: str) -> List[Dict]: + """ + 从给定的文本中提取实体,并为每个实体生成唯一的ID和描述。 + + 参数: + - text: 输入的文本 + - chunk_id: 文本块的ID + + 返回: + - 包含提取的实体信息的列表 + """ + # 使用语言模型预测实体信息 + data = self.llm.predict(GET_ENTITY.format(text=text)) + concepts = [] # 用于存储提取的实体信息 + + # 从预测结果中提取实体信息 + for concept_html in get_text_inside_tag(data, "concept"): + concept = {} + concept["name"] = get_text_inside_tag(concept_html, "name")[0].strip() + concept["description"] = get_text_inside_tag(concept_html, "description")[ + 0 + ].strip() + concept["chunks id"] = [chunk_id] + concept["entity id"] = compute_mdhash_id( + concept["description"], prefix="entity-" + ) + concepts.append(concept) + + return concepts + + def get_triplets(self, content, entity: list) -> List[Dict]: + """ + 从给定的内容中提取三元组(Triplet)信息,并返回包含这些三元组信息的列表。 + + 参数: + - content: 输入的内容 + - entity: 实体列表 + + 返回: + - 包含提取的三元组信息的列表 + """ + try: + # 使用语言模型预测三元组信息 + data = self.llm.predict(GET_TRIPLETS.format(text=content, entity=entity)) + data = get_text_inside_tag(data, "triplet") + except Exception as e: + print(f"Error predicting triplets: {e}") + return [] + + res = [] # 用于存储提取的三元组信息 + + # 从预测结果中提取三元组信息 + for triplet_data in data: + try: + subject = get_text_inside_tag(triplet_data, "subject")[0] + subject_id = 
get_text_inside_tag(triplet_data, "subject_id")[0] + predicate = get_text_inside_tag(triplet_data, "predicate")[0] + object = get_text_inside_tag(triplet_data, "object")[0] + object_id = get_text_inside_tag(triplet_data, "object_id")[0] + res.append( + { + "subject": subject, + "subject_id": subject_id, + "predicate": predicate, + "object": object, + "object_id": object_id, + } + ) + except Exception as e: + print(f"Error extracting triplet: {e}") + continue + + return res + + def add_document(self, filepath, use_llm_deambiguation=False) -> None: + """ + 将文档添加到系统中,执行以下步骤: + 1. 检查文档是否已经加载。 + 2. 将文档分割成块。 + 3. 从块中提取实体和三元组。 + 4. 执行实体消岐,有两种方法可选,默认将同名实体认为即为同一实体。 + 5. 合并实体和三元组。 + 6. 将合并的实体和三元组存储到Neo4j数据库中。 + + 参数: + - filepath: 要添加的文档的路径 + - use_llm_deambiguation: 是否使用LLM进行实体消岐 + """ + # ================ Check if the document has been loaded ================ + if filepath in self.get_loaded_documents(): + print( + f"Document '{filepath}' has already been loaded, skipping import process." + ) + return + + # ================ Chunking ================ + chunks = self.split_text(filepath) + existing_chunks = read_json_file(self.chunk_path) + + # Filter out chunks that are already in storage + new_chunks = {k: v for k, v in chunks.items() if k not in existing_chunks} + + if not new_chunks: + print("All chunks are already in the storage.") + return + + # Merge new chunks with existing chunks + all_chunks = {**existing_chunks, **new_chunks} + write_json_file(all_chunks, self.chunk_path) + print(f"Document '{filepath}' has been chunked.") + + # ================ Entity Extraction ================ + all_entities = [] + all_triplets = [] + + for chunk_id, chunk_content in tqdm( + new_chunks.items(), desc=f"Processing '{filepath}'" + ): + try: + entities = self.get_entity(chunk_content, chunk_id=chunk_id) + all_entities.extend(entities) + triplets = self.get_triplets(chunk_content, entities) + all_triplets.extend(triplets) + except: + print( + f"An error occurred while processing chunk '{chunk_id}'. SKIPPING..." + ) + + print( + f"{len(all_entities)} entities and {len(all_triplets)} triplets have been extracted." 
+ ) + # ================ Entity Disambiguation ================ + entity_names = list(set(entity["name"] for entity in all_entities)) + + if use_llm_deambiguation: + entity_id_mapping = {} + for name in entity_names: + same_name_entities = [ + entity for entity in all_entities if entity["name"] == name + ] + transform_text = self.llm.predict( + ENTITY_DISAMBIGUATION.format(same_name_entities) + ) + entity_id_mapping.update( + get_text_inside_tag(transform_text, "transform") + ) + else: + entity_id_mapping = {} + for entity in all_entities: + entity_name = entity["name"] + if entity_name not in entity_id_mapping: + entity_id_mapping[entity_name] = entity["entity id"] + + for entity in all_entities: + entity["entity id"] = entity_id_mapping.get( + entity["name"], entity["entity id"] + ) + + triplets_to_remove = [ + triplet + for triplet in all_triplets + if entity_id_mapping.get(triplet["subject"], triplet["subject_id"]) is None + or entity_id_mapping.get(triplet["object"], triplet["object_id"]) is None + ] + + updated_triplets = [ + { + **triplet, + "subject_id": entity_id_mapping.get( + triplet["subject"], triplet["subject_id"] + ), + "object_id": entity_id_mapping.get( + triplet["object"], triplet["object_id"] + ), + } + for triplet in all_triplets + if triplet not in triplets_to_remove + ] + all_triplets = updated_triplets + + # ================ Merge Entities ================ + entity_map = {} + + for entity in all_entities: + entity_id = entity["entity id"] + if entity_id not in entity_map: + entity_map[entity_id] = { + "name": entity["name"], + "description": entity["description"], + "chunks id": [], + "entity id": entity_id, + } + else: + entity_map[entity_id]["description"] += " " + entity["description"] + + entity_map[entity_id]["chunks id"].extend(entity["chunks id"]) + # ================ Store Data in Neo4j ================ + for triplet in all_triplets: + subject_id = triplet["subject_id"] + object_id = triplet["object_id"] + + subject = entity_map.get(subject_id) + object = entity_map.get(object_id) + if subject and object: + self.create_triplet(subject, triplet["predicate"], object) + # ================ communities ================ + self.gen_community() + self.generate_community_report() + # ================ embedding ================ + self.add_embedding_for_graph() + self.add_loaded_documents(filepath) + print(f"doc '{filepath}' has been loaded.") + + def detect_communities(self) -> None: + query = """ + CALL gds.graph.project( + 'graph_help', + ['Entity'], + { + Relationship: { + orientation: 'UNDIRECTED' + } + } + ) + """ + with self.driver.session() as session: + result = session.run(query) + + query = """ + CALL gds.leiden.write('graph_help', { + writeProperty: 'communityIds', + includeIntermediateCommunities: True, + maxLevels: 10, + tolerance: 0.0001, + gamma: 1.0, + theta: 0.01 + }) + YIELD communityCount, modularity, modularities + """ + with self.driver.session() as session: + result = session.run(query) + for record in result: + print( + f"社区数量: {record['communityCount']}, 模块度: {record['modularity']}" + ) + session.run("CALL gds.graph.drop('graph_help')") + + def get_entity_by_name(self, name): + query = """ + MATCH (n:Entity {name: $name}) + RETURN n + """ + with self.driver.session() as session: + result = session.run(query, name=name) + entities = [record["n"].get("name") for record in result] + return entities[0] + + def get_node_edgs(self, node: Node): + query = """ + MATCH (n)-[r]-(m) + WHERE n.entity_id = $id + RETURN n.name AS n,r.name AS r,m.name AS m + 
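+        // 注:按无向方式匹配,返回与该实体相连的所有关系边(不区分方向)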
""" + with self.driver.session() as session: + result = session.run(query, id=node.entity_id) + edges = [(record["n"], record["r"], record["m"]) for record in result] + return edges + + def get_node_chunks(self, node): + existing_chunks = read_json_file(self.chunk_path) + chunks = [existing_chunks[i] for i in node.chunks_id] + return chunks + + def add_embedding_for_graph(self): + query = """ + MATCH (n) + RETURN n + """ + with self.driver.session() as session: + result = session.run(query) + for record in result: + node = record["n"] + description = node["description"] + id = node["entity_id"] + embedding = self.embedding.get_emb(description) + # 更新节点,添加新的 embedding 属性 + update_query = """ + MATCH (n {entity_id: $id}) + SET n.embedding = $embedding + """ + session.run(update_query, id=id, embedding=embedding) + + def get_topk_similar_entities(self, input_emb, k=1) -> List[Node]: + res = [] + query = """ + MATCH (n) + RETURN n + """ + with self.driver.session() as session: + result = session.run(query) + # 如果遇到报错:ResultConsumedError: The result has been consumed. Fetch all needed records before calling Result.consume().可将result = session.run(query)修改为result = list(session.run(query)) + for record in result: + node = record["n"] + if node["embedding"] is not None: + similarity = cosine_similarity(input_emb, node["embedding"]) + node = Node( + name=node["name"], + desc=node["description"], + chunks_id=node["chunks_id"], + entity_id=node["entity_id"], + similarity=similarity, + ) + res.append(node) + return sorted(res, key=lambda x: x.similarity, reverse=True)[:k] + + def get_communities(self, nodes: List[Node]): + communities_schema = self.read_community_schema() + res = [] + nodes_ids = [i.entity_id for i in nodes] + for community_id, community_info in communities_schema.items(): + if set(nodes_ids) & set(community_info["nodes"]): + res.append( + { + "community_id": community_id, + "community_info": community_info["report"], + } + ) + return res + + def get_relations(self, nodes: List, input_emb): + res = [] + for i in nodes: + res.append(self.get_node_edgs(i)) + return res + + def get_chunks(self, nodes, input_emb): + chunks = [] + for i in nodes: + chunks.append(self.get_node_chunks(i)) + return chunks + + def gen_community_schema(self) -> dict[str, dict]: + results = defaultdict( + lambda: dict( + level=None, + title=None, + edges=set(), + nodes=set(), + chunk_ids=set(), + sub_communities=[], + ) + ) + + with self.driver.session() as session: + # Fetch community data + result = session.run( + f""" + MATCH (n:Entity) + WITH n, n.communityIds AS communityIds, [(n)-[]-(m:Entity) | m.entity_id] AS connected_nodes + RETURN n.entity_id AS node_id, + communityIds AS cluster_key, + connected_nodes + """ + ) + + max_num_ids = 0 + for record in result: + for index, c_id in enumerate(record["cluster_key"]): + node_id = str(record["node_id"]) + level = index + cluster_key = str(c_id) + connected_nodes = record["connected_nodes"] + + results[cluster_key]["level"] = level + results[cluster_key]["title"] = f"Cluster {cluster_key}" + results[cluster_key]["nodes"].add(node_id) + results[cluster_key]["edges"].update( + [ + tuple(sorted([node_id, str(connected)])) + for connected in connected_nodes + if connected != node_id + ] + ) + for k, v in results.items(): + v["edges"] = [list(e) for e in v["edges"]] + v["nodes"] = list(v["nodes"]) + v["chunk_ids"] = list(v["chunk_ids"]) + for cluster in results.values(): + cluster["sub_communities"] = [ + sub_key + for sub_key, sub_cluster in results.items() + if 
sub_cluster["level"] > cluster["level"] + and set(sub_cluster["nodes"]).issubset(set(cluster["nodes"])) + ] + + return dict(results) + + def gen_community(self): + self.detect_communities() + community_schema = self.gen_community_schema() + with open(self.community_path, "w", encoding="utf-8") as file: + json.dump(community_schema, file, indent=4) + + def read_community_schema(self) -> dict: + try: + with open(self.community_path, "r", encoding="utf-8") as file: + community_schema = json.load(file) + except: + raise FileNotFoundError( + "Community schema not found. Please make sure to generate it first." + ) + return community_schema + + def get_loaded_documents(self): + try: + with open(self.doc_path, "r", encoding="utf-8") as file: + lines = file.readlines() + return set(line.strip() for line in lines) + except: + raise FileNotFoundError("Cache file not found.") + + def add_loaded_documents(self, file_path): + if file_path in self.loaded_documents: + print( + f"Document '{file_path}' has already been loaded, skipping addition to cache." + ) + return + with open(self.doc_path, "a", encoding="utf-8") as file: + file.write(file_path + "\n") + self.loaded_documents.add(file_path) + + def get_node_by_id(self, node_id): + query = """ + MATCH (n:Entity {entity_id: $node_id}) + RETURN n + """ + with self.driver.session() as session: + result = session.run(query, node_id=node_id) + nodes = [record["n"] for record in result] + return nodes[0] + + def get_edges_by_id(self, src, tar): + query = """ + MATCH (n:Entity {entity_id: $src})-[r]-(m:Entity {entity_id: $tar}) + RETURN {src: n.name, r: r.name, tar: m.name} AS R + """ + with self.driver.session() as session: + result = session.run(query, {"src": src, "tar": tar}) + edges = [record["R"] for record in result] + return edges[0] + + def gen_single_community_report(self, community: dict): + nodes = community["nodes"] + edges = community["edges"] + nodes_describe = [] + edges_describe = [] + for i in nodes: + node = self.get_node_by_id(i) + nodes_describe.append({"name": node["name"], "desc": node["description"]}) + for i in edges: + edge = self.get_edges_by_id(i[0], i[1]) + edges_describe.append( + {"source": edge["src"], "target": edge["tar"], "desc": edge["r"]} + ) + nodes_csv = "entity,description\n" + for node in nodes_describe: + nodes_csv += f"{node['name']},{node['desc']}\n" + edges_csv = "source,target,description\n" + for edge in edges_describe: + edges_csv += f"{edge['source']},{edge['target']},{edge['desc']}\n" + data = f""" + Text: + -----Entities----- + ```csv + {nodes_csv} + ``` + -----Relationships----- + ```csv + {edges_csv} + ```""" + prompt = GEN_COMMUNITY_REPORT.format(input_text=data) + report = self.llm.predict(prompt) + return report + + def generate_community_report(self): + communities_schema = self.read_community_schema() + for community_key, community in tqdm( + communities_schema.items(), desc="generating community report" + ): + community["report"] = self.gen_single_community_report(community) + with open(self.community_path, "w", encoding="utf-8") as file: + json.dump(communities_schema, file, indent=4) + print("All community report has been generated.") + + def build_local_query_context(self, query): + query_emb = self.embedding.get_emb(query) + topk_similar_entities_context = self.get_topk_similar_entities(query_emb) + topk_similar_communities_context = self.get_communities( + topk_similar_entities_context + ) + topk_similar_relations_context = self.get_relations( + topk_similar_entities_context, query + ) + 
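+        # 注:get_relations 与 get_chunks 目前并未使用传入的 query 参数,仅根据上一步得到的相似实体收集其关系边和关联文档块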
topk_similar_chunks_context = self.get_chunks( + topk_similar_entities_context, query + ) + return f""" + -----Reports----- + ```csv + {topk_similar_communities_context} + ``` + -----Entities----- + ```csv + {topk_similar_entities_context} + ``` + -----Relationships----- + ```csv + {topk_similar_relations_context} + ``` + -----Sources----- + ```csv + {topk_similar_chunks_context} + ``` + """ + + def map_community_points(self, community_info, query): + points_html = self.llm.predict( + GLOBAL_MAP_POINTS.format(context_data=community_info, query=query) + ) + points = get_text_inside_tag(points_html, "point") + res = [] + for point in points: + try: + score = get_text_inside_tag(point, "score")[0] + desc = get_text_inside_tag(point, "description")[0] + res.append((desc, score)) + except: + continue + return res + + def build_global_query_context(self, query, level=1): + communities_schema = self.read_community_schema() + candidate_community = {} + points = [] + for communityid, community_info in communities_schema.items(): + if community_info["level"] < level: + candidate_community.update({communityid: community_info}) + for communityid, community_info in candidate_community.items(): + points.extend(self.map_community_points(community_info["report"], query)) + points = sorted(points, key=lambda x: x[-1], reverse=True) + return points + + def local_query(self, query): + context = self.build_local_query_context(query) + prompt = LOCAL_QUERY.format(query=query, context=context) + response = self.llm.predict(prompt) + return response + + def global_query(self, query, level=1): + context = self.build_global_query_context(query, level) + prompt = GLOBAL_QUERY.format(query=query, context=context) + response = self.llm.predict(prompt) + return response diff --git a/content/TinyGraphRAG/tinygraph/llm/__init__.py b/content/TinyGraphRAG/tinygraph/llm/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/content/TinyGraphRAG/tinygraph/llm/base.py b/content/TinyGraphRAG/tinygraph/llm/base.py new file mode 100644 index 0000000..1832c45 --- /dev/null +++ b/content/TinyGraphRAG/tinygraph/llm/base.py @@ -0,0 +1,32 @@ +from abc import ABC, abstractmethod +from typing import Any, Optional + + +class BaseLLM(ABC): + """Interface for large language models. + + Args: + model_name (str): The name of the language model. + model_params (Optional[dict[str, Any]], optional): Additional parameters passed to the model when text is sent to it. Defaults to None. + **kwargs (Any): Arguments passed to the model when for the class is initialised. Defaults to None. + """ + + def __init__( + self, + model_name: str, + model_params: Optional[dict[str, Any]] = None, + **kwargs: Any, + ): + self.model_name = model_name + self.model_params = model_params or {} + + @abstractmethod + def predict(self, input: str) -> str: + """Sends a text input to the LLM and retrieves a response. + + Args: + input (str): Text sent to the LLM + + Returns: + str: The response from the LLM. 
+ """ diff --git a/content/TinyGraphRAG/tinygraph/llm/groq.py b/content/TinyGraphRAG/tinygraph/llm/groq.py new file mode 100644 index 0000000..34da768 --- /dev/null +++ b/content/TinyGraphRAG/tinygraph/llm/groq.py @@ -0,0 +1,32 @@ +from groq import Groq +from typing import Any, Optional +from .base import BaseLLM + + +class groqLLM(BaseLLM): + """Implementation of the BaseLLM interface using zhipuai.""" + + def __init__( + self, + model_name: str, + api_key: str, + model_params: Optional[dict[str, Any]] = None, + **kwargs: Any, + ): + super().__init__(model_name, model_params, **kwargs) + self.client = Groq(api_key=api_key) + + def predict(self, input: str) -> str: + """Sends a text input to the zhipuai model and retrieves a response. + + Args: + input (str): Text sent to the zhipuai model + + Returns: + str: The response from the zhipuai model. + """ + response = self.client.chat.completions.create( + model=self.model_name, + messages=[{"role": "user", "content": input}], + ) + return response.choices[0].message.content diff --git a/content/TinyGraphRAG/tinygraph/llm/zhipu.py b/content/TinyGraphRAG/tinygraph/llm/zhipu.py new file mode 100644 index 0000000..1031e34 --- /dev/null +++ b/content/TinyGraphRAG/tinygraph/llm/zhipu.py @@ -0,0 +1,32 @@ +from zhipuai import ZhipuAI +from typing import Any, Optional +from .base import BaseLLM + + +class zhipuLLM(BaseLLM): + """Implementation of the BaseLLM interface using zhipuai.""" + + def __init__( + self, + model_name: str, + api_key: str, + model_params: Optional[dict[str, Any]] = None, + **kwargs: Any, + ): + super().__init__(model_name, model_params, **kwargs) + self.client = ZhipuAI(api_key=api_key) + + def predict(self, input: str) -> str: + """Sends a text input to the zhipuai model and retrieves a response. + + Args: + input (str): Text sent to the zhipuai model + + Returns: + str: The response from the zhipuai model. + """ + response = self.client.chat.completions.create( + model=self.model_name, + messages=[{"role": "user", "content": input}], + ) + return response.choices[0].message.content diff --git a/content/TinyGraphRAG/tinygraph/prompt.py b/content/TinyGraphRAG/tinygraph/prompt.py new file mode 100644 index 0000000..61164f7 --- /dev/null +++ b/content/TinyGraphRAG/tinygraph/prompt.py @@ -0,0 +1,383 @@ +GEN_NODES = """ +## Goal +Please identify and extract triplet information from the provided article, focusing only on entities and relationships related to significant knowledge points. +Each triplet should be in the form of (Subject, Predicate, Object). +Follow these guidelines: + +1. **Subject:** Concepts in Bayesian Optimization +2. **Predicate:** The action or relationship that links the subject to the object. +3. **Object:** Concepts in Bayesian Optimization that is affected by or related to the action of the subject. + +## Example +For the sentence "Gaussian Processes are used to model the objective function in Bayesian Optimization" the triplet would be: + +Gaussian Processesare used to model the objective function inBayesian Optimization + +For the sentence "John read a book on the weekend," which is not related to any knowledge points, no triplet should be extracted. + +## Instructions +1. Read through the article carefully. +2. Think step by step. Try to find some useful knowledge points from the article. You need to reorganize the content of the sentence into corresponding knowledge points. +3. Identify key sentences that contain relevant triplet information related to significant knowledge points. +4. 
Extract and format the triplets as per the given example, excluding any information that is not relevant to significant knowledge points. + +## Output Format +For each identified triplet, provide: +[Entity]The action or relationshipThe entity + +## Article + +{text} + +## Your response +""" + +GET_ENTITY = """ +## Goal + +You are an experienced machine learning teacher. +You need to identify the key concepts related to machine learning that the article requires students to master. For each concept, provide a brief description that explains its relevance and importance in the context of the article. + +## Example + +article: +"In the latest study, we explored the potential of using machine learning algorithms for disease prediction. We used support vector machines (SVM) and random forest algorithms to analyze medical data. The results showed that these models performed well in predicting disease risk through feature selection and cross-validation. In particular, the random forest model showed better performance in dealing with overfitting problems. In addition, we discussed the application of deep learning in medical image analysis." + +response: + + Support Vector Machine (SVM) + A supervised learning model used for classification and regression tasks, particularly effective in high-dimensional spaces. + + + Random Forest Algorithm + An ensemble learning method that builds multiple decision trees and merges them together to get a more accurate and stable prediction, often used to reduce overfitting. + + + Feature Selection + The process of selecting a subset of relevant features for use in model construction, crucial for improving model performance and reducing complexity. + + + Overfitting + A common issue where a model learns the details and noise in the training data to the extent that it negatively impacts the model's performance on new data. + + + Deep Learning + A subset of machine learning that uses neural networks with many layers to model complex patterns in large datasets, often applied in image and speech recognition tasks. + + +## Format + +Wrap each concept in the HTML tag , and include the name of the concept in the tag and its description in the tag. + +## Article + +{text} + +## Your response +""" + + +ENTITY_DISAMBIGUATION = """ +## Goal +Given multiple entities with the same name, determine if they can be merged into a single entity. If merging is possible, provide the transformation from entity id to entity id. + +## Guidelines +1. **Entities:** A list of entities with the same name. +2. **Merge:** Determine if the entities can be merged into a single entity. +3. **Transformation:** If merging is possible, provide the transformation from entity id to entity id. + +## Example +1. Entities: + [ + {"name": "Entity A", "entity id": "entity-1"}, + {"name": "Entity A", "entity id": "entity-2"}, + {"name": "Entity A", "entity id": "entity-3"} + ] + +Your response should be: + +{"entity-2": "entity-1", "entity-3": "entity-1"} + + +2. Entities: + [ + {"name": "Entity B", "entity id": "entity-4"}, + {"name": "Entity C", "entity id": "entity-5"}, + {"name": "Entity B", "entity id": "entity-6"} + ] + +Your response should be: + +None + +## Output Format +Provide the following information: +- Transformation: A dictionary mapping entity ids to the final entity id after merging. + +## Given Entities +{entities} + +## Your response +""" + +GET_TRIPLETS = """ +## Goal +Identify and extract all the relationships between the given concepts from the provided text. 
+Identify as many relationships between the concepts as possible.
+The relationship in the triple should accurately reflect the interaction or connection between the two concepts.
+
+## Guidelines:
+1. **Subject:** The first entity from the given entities.
+2. **Predicate:** The action or relationship linking the subject to the object.
+3. **Object:** The second entity from the given entities.
+
+## Example:
+1. Article:
+   "Gaussian Processes are used to model the objective function in Bayesian Optimization"
+   Given entities:
+   [{{"name": "Gaussian Processes", "entity id": "entity-1"}}, {{"name": "Bayesian Optimization", "entity id": "entity-2"}}]
+   Output:
+   Gaussian Processesentity-1are used to model the objective function inBayesian Optimizationentity-2
+
+2. Article:
+   "Hydrogen is a colorless, odorless, non-toxic gas and is the lightest and most abundant element in the universe. Oxygen is a gas that supports combustion and is widely present in the Earth's atmosphere. Water is a compound made up of hydrogen and oxygen, with the chemical formula H2O."
+   Given entities:
+   [{{"name": "Hydrogen", "entity id": "entity-3"}}, {{"name": "Oxygen", "entity id": "entity-4"}}, {{"name": "Water", "entity id": "entity-5"}}]
+   Output:
+   Hydrogenentity-3is a component ofWaterentity-5
+3. Article:
+   "John read a book on the weekend"
+   Given entities:
+   []
+   Output:
+   None
+
+## Format:
+For each identified triplet, provide:
+**both entities must come from the "Given Entities" list**
+[Entity][Entity ID][The action or relationship][Entity][Entity ID]
+
+## Given Entities:
+{entity}
+
+### Article:
+{text}
+
+## Additional Instructions:
+- Before giving your response, you should analyze and think about it sentence by sentence.
+- Both the subject and object must be selected from the given entities and cannot change their content.
+- If no relevant triplet involving both entities is found, no triplet should be extracted.
+- If there are similar concepts, please rewrite them into a form that suits our requirements.
+
+## Your response:
+"""
+
+TEST_PROMPT = """
+## Foundation of students
+{state}
+## Goal
+You will help students solve questions through multiple rounds of dialogue.
+Please follow the steps below to help students solve the question:
+1. Explain the basic knowledge and principles behind the question and make sure the other party understands these basic concepts.
+2. Don't give a complete answer directly, but guide the student to think about the key steps of the question.
+3. After guiding the student to think, let them try to solve the question by themselves. Give appropriate hints and feedback to help them correct their mistakes and further improve their solutions.
+4. Return TERMINATE after solving the problem.
+"""
+
+GEN_COMMUNITY_REPORT = """
+## Role
+You are an AI assistant that helps a human analyst to perform general information discovery.
+Information discovery is the process of identifying and assessing relevant information associated with certain entities (e.g., organizations and individuals) within a network.
+
+## Goal
+Write a comprehensive report of a community.
+Given a list of entities that belong to the community as well as their relationships and optional associated claims. The report will be used to inform decision-makers about information associated with the community and their potential impact.
+The content of this report includes an overview of the community's key entities, their legal compliance, technical capabilities, reputation, and noteworthy claims. 
+ +## Report Structure + +The report should include the following sections: + +- TITLE: community's name that represents its key entities - title should be short but specific. When possible, include representative named entities in the title. +- SUMMARY: An executive summary of the community's overall structure, how its entities are related to each other, and significant information associated with its entities. +- DETAILED FINDINGS: A list of 5-10 key insights about the community. Each insight should have a short summary followed by multiple paragraphs of explanatory text grounded according to the grounding rules below. Be comprehensive. + +Return output as a well-formed JSON-formatted string with the following format: +{{ +"title": , +"summary": , +"findings": [ +{{ +"summary":, +"explanation": +}}, +{{ +"summary":, +"explanation": +}} +... +] +}} + +## Grounding Rules +Do not include information where the supporting evidence for it is not provided. + +## Example Input +----------- +Text: +``` +Entities: +```csv +entity,description +VERDANT OASIS PLAZA,Verdant Oasis Plaza is the location of the Unity March +HARMONY ASSEMBLY,Harmony Assembly is an organization that is holding a march at Verdant Oasis Plaza +``` +Relationships: +```csv +source,target,description +VERDANT OASIS PLAZA,UNITY MARCH,Verdant Oasis Plaza is the location of the Unity March +VERDANT OASIS PLAZA,HARMONY ASSEMBLY,Harmony Assembly is holding a march at Verdant Oasis Plaza +VERDANT OASIS PLAZA,UNITY MARCH,The Unity March is taking place at Verdant Oasis Plaza +VERDANT OASIS PLAZA,TRIBUNE SPOTLIGHT,Tribune Spotlight is reporting on the Unity march taking place at Verdant Oasis Plaza +VERDANT OASIS PLAZA,BAILEY ASADI,Bailey Asadi is speaking at Verdant Oasis Plaza about the march +HARMONY ASSEMBLY,UNITY MARCH,Harmony Assembly is organizing the Unity March +``` +``` +Output: +{{ +"title": "Verdant Oasis Plaza and Unity March", +"summary": "The community revolves around the Verdant Oasis Plaza, which is the location of the Unity March. The plaza has relationships with the Harmony Assembly, Unity March, and Tribune Spotlight, all of which are associated with the march event.", +"findings": [ +{{ +"summary": "Verdant Oasis Plaza as the central location", +"explanation": "Verdant Oasis Plaza is the central entity in this community, serving as the location for the Unity March. This plaza is the common link between all other entities, suggesting its significance in the community. The plaza's association with the march could potentially lead to issues such as public disorder or conflict, depending on the nature of the march and the reactions it provokes." +}}, +{{ +"summary": "Harmony Assembly's role in the community", +"explanation": "Harmony Assembly is another key entity in this community, being the organizer of the march at Verdant Oasis Plaza. The nature of Harmony Assembly and its march could be a potential source of threat, depending on their objectives and the reactions they provoke. The relationship between Harmony Assembly and the plaza is crucial in understanding the dynamics of this community." +}}, +{{ +"summary": "Unity March as a significant event", +"explanation": "The Unity March is a significant event taking place at Verdant Oasis Plaza. This event is a key factor in the community's dynamics and could be a potential source of threat, depending on the nature of the march and the reactions it provokes. The relationship between the march and the plaza is crucial in understanding the dynamics of this community." 
+}}, +{{ +"summary": "Role of Tribune Spotlight", +"explanation": "Tribune Spotlight is reporting on the Unity March taking place in Verdant Oasis Plaza. This suggests that the event has attracted media attention, which could amplify its impact on the community. The role of Tribune Spotlight could be significant in shaping public perception of the event and the entities involved." +}} +] +}} + +## Real Data +Use the following text for your answer. Do not make anything up in your answer. + +Text: +``` +{input_text} +``` + +The report should include the following sections: + +- TITLE: community's name that represents its key entities - title should be short but specific. When possible, include representative named entities in the title. +- SUMMARY: An executive summary of the community's overall structure, how its entities are related to each other, and significant information associated with its entities. +- DETAILED FINDINGS: A list of 5-10 key insights about the community. Each insight should have a short summary followed by multiple paragraphs of explanatory text grounded according to the grounding rules below. Be comprehensive. + +Return output as a well-formed JSON-formatted string with the following format: +{{ +"title": , +"summary": , +"rating": , +"rating_explanation": , +"findings": [ +{{ +"summary":, +"explanation": +}}, +{{ +"summary":, +"explanation": +}} +... +] +}} + +## Grounding Rules +Do not include information where the supporting evidence for it is not provided. + +Output: +""" + +GLOBAL_MAP_POINTS = """ +You are a helpful assistant responding to questions about data in the tables provided. + + +---Goal--- + +Generate a response consisting of a list of key points that responds to the user's question, summarizing all relevant information in the input data tables. + +You should use the data provided in the data tables below as the primary context for generating the response. +If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up. + +Each key point in the response should have the following element: +- Description: A comprehensive description of the point. +- Importance Score: An integer score between 0-100 that indicates how important the point is in answering the user's question. An 'I don't know' type of response should have a score of 0. + +The response should be HTML formatted as follows: + + +"Description of point 1..."score_value +"Description of point 2..."score_value + + +The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will". +Do not include information where the supporting evidence for it is not provided. + + +---Data tables--- + +{context_data} + +---User query--- + +{query} + +---Goal--- + +Generate a response consisting of a list of key points that responds to the user's question, summarizing all relevant information in the input data tables. + +You should use the data provided in the data tables below as the primary context for generating the response. +If you don't know the answer or if the input data tables do not contain sufficient information to provide an answer, just say so. Do not make anything up. + +Each key point in the response should have the following element: +- Description: A comprehensive description of the point. +- Importance Score: An integer score between 0-100 that indicates how important the point is in answering the user's question. 
An 'I don't know' type of response should have a score of 0.
+The response shall preserve the original meaning and use of modal verbs such as "shall", "may" or "will".
+Do not include information where the supporting evidence for it is not provided.
+
+The response should be HTML formatted as follows:
+
+"Description of point 1..."score_value
+"Description of point 2..."score_value
+
+
+"""
+
+LOCAL_QUERY = """
+## User Query
+{query}
+## Context
+{context}
+## Task
+Based on the given context, please provide a response to the user query.
+## Your Response
+"""
+
+GLOBAL_QUERY = """
+## User Query
+{query}
+## Context
+{context}
+## Task
+Based on the given context, please provide a response to the user query.
+## Your Response
+"""
diff --git a/content/TinyGraphRAG/tinygraph/utils.py b/content/TinyGraphRAG/tinygraph/utils.py
new file mode 100644
index 0000000..decaf58
--- /dev/null
+++ b/content/TinyGraphRAG/tinygraph/utils.py
@@ -0,0 +1,55 @@
+import re
+import numpy as np
+from typing import List, Tuple
+from hashlib import md5
+import json
+import os
+
+
+def get_text_inside_tag(html_string: str, tag: str):
+    # html_string is the text to parse; tag is the tag name to extract
+    pattern = f"<{tag}>(.*?)</{tag}>"
+    try:
+        result = re.findall(pattern, html_string, re.DOTALL)
+        return result
+    except re.error as e:
+        raise ValueError("Regex error while parsing tag <{tag}>: {error}".format(tag=tag, error=e))
+
+
+def read_json_file(file_path):
+    try:
+        with open(file_path, "r", encoding="utf-8") as file:
+            return json.load(file)
+    except (OSError, json.JSONDecodeError):
+        return {}
+
+
+def write_json_file(data, file_path):
+    with open(file_path, "w", encoding="utf-8") as file:
+        json.dump(data, file, indent=4, ensure_ascii=False)
+
+
+def compute_mdhash_id(content, prefix: str = ""):
+    return prefix + md5(content.encode()).hexdigest()
+
+
+def save_triplets_to_txt(triplets, file_path):
+    with open(file_path, "a", encoding="utf-8") as file:
+        file.write(f"{triplets[0]},{triplets[1]},{triplets[2]}\n")
+
+
+def cosine_similarity(vector1: List[float], vector2: List[float]) -> float:
+    """
+    Calculate the cosine similarity between two vectors.
+    """
+    dot_product = np.dot(vector1, vector2)
+    magnitude = np.linalg.norm(vector1) * np.linalg.norm(vector2)
+    if not magnitude:
+        return 0.0
+    return float(dot_product / magnitude)
+
+
+def create_file_if_not_exists(file_path: str):
+    if not os.path.exists(file_path):
+        with open(file_path, "w") as f:
+            f.write("")
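
To show how the pieces added in this diff fit together, here is a minimal usage sketch that chains a prompt template, the `zhipuLLM` wrapper, and the helpers from `tinygraph/utils.py`. It is illustrative only, not the pipeline in `tinygraph/graph.py`: the tag names `concept`, `name`, and `description`, the example model name, and the `entity-` id prefix are assumptions, since the exact tags are not visible in the extracted prompt text above.

```python
# Minimal sketch (not the actual TinyGraph pipeline): extract concepts from one
# text chunk with the GET_ENTITY prompt and turn them into stable entity ids.
# Assumed: <concept>/<name>/<description> tag names and the "entity-" id prefix.
import os

from tinygraph.llm.zhipu import zhipuLLM
from tinygraph.prompt import GET_ENTITY
from tinygraph.utils import compute_mdhash_id, get_text_inside_tag

llm = zhipuLLM(model_name="glm-3-turbo", api_key=os.getenv("API_KEY"))  # example model name

chunk = "Support vector machines and random forests are common supervised learners."

# 1. Fill the {text} placeholder and send the prompt to the LLM.
response = llm.predict(GET_ENTITY.format(text=chunk))

# 2. Parse the tagged concepts out of the raw completion.
for concept in get_text_inside_tag(response, "concept"):        # assumed tag name
    names = get_text_inside_tag(concept, "name")                 # assumed tag name
    descriptions = get_text_inside_tag(concept, "description")   # assumed tag name
    if not names:
        continue
    entity_id = compute_mdhash_id(names[0].strip(), prefix="entity-")  # assumed prefix
    print(entity_id, names[0].strip(), descriptions[0].strip() if descriptions else "")
```

The same pattern can apply downstream: `GET_TRIPLETS` output can be parsed with `get_text_inside_tag` as well, `cosine_similarity` can rank chunk embeddings against a query embedding, and `LOCAL_QUERY` / `GLOBAL_QUERY` format the retrieved context for the final answer.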