diff --git a/PathRAG/llm.py b/PathRAG/llm.py
index 80b5580..959e92f 100644
--- a/PathRAG/llm.py
+++ b/PathRAG/llm.py
@@ -12,6 +12,8 @@
 import ollama
 import torch
 import time
+import modelscope as ms
+from vllm import LLM
 from openai import (
     AsyncOpenAI,
     APIConnectionError,
@@ -36,6 +38,7 @@
 )
 import sys
 
+device = "cuda" if torch.cuda.is_available() else "cpu"
 if sys.version_info < (3, 9):
     from typing import AsyncIterator
@@ -223,16 +226,32 @@ async def bedrock_complete_if_cache(
 
 @lru_cache(maxsize=1)
 def initialize_hf_model(model_name):
-    hf_tokenizer = AutoTokenizer.from_pretrained(
-        model_name, device_map="auto", trust_remote_code=True
+    # Use the GPU when available, otherwise fall back to the CPU
+    use_gpu = torch.cuda.is_available()
+    dtype = torch.bfloat16 if use_gpu else torch.float32
+    device_map = "auto" if use_gpu else "cpu"
+
+    # Load the tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_name,
+        trust_remote_code=True,
+        use_fast=False  # some models, such as Qwen, do not support the fast tokenizer
     )
-    hf_model = AutoModelForCausalLM.from_pretrained(
-        model_name, device_map="auto", trust_remote_code=True
+
+    # Load the model
+    model = AutoModelForCausalLM.from_pretrained(
+        model_name,
+        trust_remote_code=True,
+        torch_dtype=dtype,
+        device_map=device_map,
+        low_cpu_mem_usage=True
     )
-    if hf_tokenizer.pad_token is None:
-        hf_tokenizer.pad_token = hf_tokenizer.eos_token
-    return hf_model, hf_tokenizer
+    # Make sure a pad token is set
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    return model, tokenizer
 
 
 @retry(
@@ -290,15 +309,14 @@ async def hf_model_if_cache(
     input_ids = hf_tokenizer(
         input_prompt, return_tensors="pt", padding=True, truncation=True
-    ).to("cuda")
+    )  # tensors are moved to hf_model.device on the next line instead of hard-coding "cuda"
     inputs = {k: v.to(hf_model.device) for k, v in input_ids.items()}
     output = hf_model.generate(
-        **input_ids, max_new_tokens=512, num_return_sequences=1, early_stopping=True
+        **input_ids, max_new_tokens=512, num_return_sequences=1
     )
     response_text = hf_tokenizer.decode(
         output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
     )
-
     return response_text
@@ -993,6 +1011,37 @@ async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
     else:
         return embeddings.detach().cpu().numpy()
 
+async def ms_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
+    device = next(embed_model.parameters()).device
+    input_ids = tokenizer(
+        texts, return_tensors="pt", padding=True, truncation=True
+    ).input_ids.to(device)
+    with torch.no_grad():
+        outputs = embed_model(input_ids)
+        embeddings = outputs.last_hidden_state.mean(dim=1)
+    if embeddings.dtype == torch.bfloat16:
+        return embeddings.detach().to(torch.float32).cpu().numpy()
+    else:
+        return embeddings.detach().cpu().numpy()
+
+
+async def local_embedding(texts: list[str], tokenizer=None, embed_model=None) -> np.ndarray:
+    if tokenizer is None or embed_model is None:
+        raise ValueError("Tokenizer and model must be provided")
+    device = next(embed_model.parameters()).device
+    encoded = tokenizer(
+        texts,
+        padding=True,
+        truncation=True,
+        return_tensors="pt"
+    ).input_ids.to(device)
+    with torch.no_grad():
+        outputs = embed_model(encoded)
+        embeddings = outputs.last_hidden_state.mean(dim=1)
+    if embeddings.dtype == torch.bfloat16:
+        return embeddings.detach().to(torch.float32).cpu().numpy()
+    else:
+        return embeddings.detach().cpu().numpy()
+
 
 async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
     """
@@ -1013,6 +1062,313 @@ async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
     return data["embeddings"]
 
 
+
+@lru_cache(maxsize=1)
+def initialize_ms_model(model_name: str):
+    """Load the ModelScope chat model and cache it (low-level helper)."""
+    tokenizer = ms.AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+    model = ms.AutoModelForCausalLM.from_pretrained(
+        model_name,
+        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
+        trust_remote_code=True,
+        low_cpu_mem_usage=True
+    ).to(device).eval()
+    return tokenizer, model
+
+
+@retry(
+    stop=stop_after_attempt(3),
+    wait=wait_exponential(multiplier=1, min=4, max=10),
+    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+)
+async def ms_model_if_cache(
+    model: str,
+    prompt: str,
+    system_prompt: str = None,
+    history_messages: list = [],
+    **kwargs,
+) -> str:
+    tokenizer, model = initialize_ms_model(model)
+    messages = []
+    if system_prompt:
+        messages.append({"role": "system", "content": system_prompt})
+    messages.extend(history_messages)
+    messages.append({"role": "user", "content": prompt})
+    kwargs.pop("hashing_kv", None)
+    input_prompt = ""
+    try:
+        input_prompt = tokenizer.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+    except Exception:
+        try:
+            ori_message = copy.deepcopy(messages)
+            if messages[0]["role"] == "system":
+                messages[1]["content"] = (
+                    ""
+                    + messages[0]["content"]
+                    + "\n"
+                    + messages[1]["content"]
+                )
+                messages = messages[1:]
+            input_prompt = tokenizer.apply_chat_template(
+                messages, tokenize=False, add_generation_prompt=True
+            )
+        except Exception:
+            len_message = len(ori_message)
+            for msgid in range(len_message):
+                input_prompt = (
+                    input_prompt
+                    + "<"
+                    + ori_message[msgid]["role"]
+                    + ">"
+                    + ori_message[msgid]["content"]
+                    + "\n"
+                )
+
+    input_ids = tokenizer(
+        input_prompt, return_tensors="pt", padding=True, truncation=True
+    ).to(device)
+    inputs = {k: v.to(model.device) for k, v in input_ids.items()}
+    output = model.generate(
+        **input_ids, max_new_tokens=512, num_return_sequences=1
+    )
+    response_text = tokenizer.decode(
+        output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
+    )
+    return response_text
+
+
+async def ms_model_complete(
+    prompt: str,
+    system_prompt: str = None,
+    history_messages: list = [],
+    keyword_extraction: bool = False,
+    **kwargs,
+) -> str:
+    keyword_extraction = kwargs.pop("keyword_extraction", None)
+    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
+    result = await ms_model_if_cache(
+        model_name,
+        prompt,
+        system_prompt=system_prompt,
+        history_messages=history_messages,
+        **kwargs,
+    )
+    if keyword_extraction:
+        return locate_json_string_body_from_string(result)
+
+    return result
+
+
+@lru_cache(maxsize=1)
+def initialize_local_model(model_path: str):
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    model = AutoModelForCausalLM.from_pretrained(
+        model_path,
+        trust_remote_code=True,
+        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+    ).to(device).eval()
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    return model, tokenizer
+
+
+@retry(
+    stop=stop_after_attempt(3),
+    wait=wait_exponential(multiplier=1, min=4, max=10),
+    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+)
+async def local_model_if_cache(
+    model,
+    prompt,
+    system_prompt=None,
+    history_messages=[],
+    **kwargs,
+) -> str:
+    model_name = model
+    local_model, local_tokenizer = initialize_local_model(model_name)
+    messages = []
+    if system_prompt:
+        messages.append({"role": "system", "content": system_prompt})
+    messages.extend(history_messages)
+    messages.append({"role": "user", "content": prompt})
+    kwargs.pop("hashing_kv", None)
+    input_prompt = ""
+    try:
+        input_prompt = local_tokenizer.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+    except Exception:
+        try:
+            ori_message = copy.deepcopy(messages)
+            if messages[0]["role"] == "system":
+                messages[1]["content"] = (
+                    ""
+                    + messages[0]["content"]
+                    + "\n"
+                    + messages[1]["content"]
+                )
+                messages = messages[1:]
+            input_prompt = local_tokenizer.apply_chat_template(
+                messages, tokenize=False, add_generation_prompt=True
+            )
+        except Exception:
+            len_message = len(ori_message)
+            for msgid in range(len_message):
+                input_prompt = (
+                    input_prompt
+                    + "<"
+                    + ori_message[msgid]["role"]
+                    + ">"
+                    + ori_message[msgid]["content"]
+                    + "\n"
+                )
+    input_ids = local_tokenizer(
+        input_prompt, return_tensors="pt", padding=True, truncation=True
+    ).to(device)
+    inputs = {k: v.to(local_model.device) for k, v in input_ids.items()}
+    output = local_model.generate(
+        **input_ids, max_new_tokens=512, num_return_sequences=1
+    )
+    response_text = local_tokenizer.decode(
+        output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
+    )
+    return response_text
+
+
+async def local_model_complete(
+    prompt: str,
+    system_prompt: str = None,
+    history_messages: list = [],
+    keyword_extraction: bool = False,
+    **kwargs,
+) -> str:
+    keyword_extraction = kwargs.pop("keyword_extraction", None)
+    model_path = kwargs["hashing_kv"].global_config["llm_model_name"]
+    result = await local_model_if_cache(
+        model_path,
+        prompt,
+        system_prompt=system_prompt,
+        history_messages=history_messages,
+        **kwargs,
+    )
+    if keyword_extraction:
+        result = locate_json_string_body_from_string(result)
+    return result
+
+
+@lru_cache(maxsize=1)
+def initialize_vllm_model(model_name: str):
+    """Load the vLLM chat model and cache it."""
+    model = LLM(model_name, device=device, max_model_len=8192)  # cap the context window at 8192 tokens
+    return model
+
+
+@retry(
+    stop=stop_after_attempt(3),
+    wait=wait_exponential(multiplier=1, min=4, max=10),
+    retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+)
+async def vllm_model_if_cache(
+    model: str,
+    prompt: str,
+    system_prompt: str = None,
+    history_messages: list = [],
+    **kwargs,
+) -> str:
+    vllm_model = initialize_vllm_model(model)
+
+    # Build the multi-turn conversation
+    messages = []
+    if system_prompt:
+        messages.append({"role": "system", "content": system_prompt})
+    messages.extend(history_messages)
+    messages.append({"role": "user", "content": prompt})
+    kwargs.pop("hashing_kv", None)
+    input_prompt = ""
+    try:
+        result = vllm_model.chat(messages)
+    except Exception:
+        try:
+            # merge the system prompt into the first user message and retry
+            ori_message = copy.deepcopy(messages)
+            if messages[0]["role"] == "system":
+                messages[1]["content"] = (
+                    ""
+                    + messages[0]["content"]
+                    + "\n"
+                    + messages[1]["content"]
+                )
+                messages = messages[1:]
+            result = vllm_model.chat(messages)
+        except Exception:
+            # last resort: flatten the conversation into a plain prompt and generate from it
+            len_message = len(ori_message)
+            for msgid in range(len_message):
+                input_prompt = (
+                    input_prompt
+                    + "<"
+                    + ori_message[msgid]["role"]
+                    + ">"
+                    + ori_message[msgid]["content"]
+                    + "\n"
+                )
+            result = vllm_model.generate([input_prompt])
+    return result[0].outputs[0].text
+
+
+async def vllm_model_complete(
+    prompt: str,
+    system_prompt: str = None,
+    history_messages: list = [],
+    keyword_extraction: bool = False,
+    **kwargs,
+) -> str:
+    keyword_extraction = kwargs.pop("keyword_extraction", None)
+    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
+    result = await vllm_model_if_cache(
+        model_name,
+        prompt,
+        system_prompt=system_prompt,
+        history_messages=history_messages,
+        **kwargs,
+    )
+    if keyword_extraction:
+        return locate_json_string_body_from_string(result)
+    return result
+
+
+async def vllm_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
+    device = next(embed_model.parameters()).device
+    input_ids = tokenizer(
+        texts, return_tensors="pt", padding=True, truncation=True
+    ).input_ids.to(device)
+    with torch.no_grad():
+        outputs = embed_model(input_ids)
+        embeddings = outputs.last_hidden_state.mean(dim=1)
+    if embeddings.dtype == torch.bfloat16:
+        return embeddings.detach().to(torch.float32).cpu().numpy()
+    else:
+        return embeddings.detach().cpu().numpy()
+
+
 class Model(BaseModel):
     """
     This is a Pydantic model class named 'Model' that is used to define a custom language model.
@@ -1099,6 +1455,5 @@ async def llm_model_func(
 
 async def main():
     result = await gpt_4o_mini_complete("How are you?")
-    print(result)
 
 asyncio.run(main())
diff --git a/README.md b/README.md
index e7ad841..f7c246f 100644
--- a/README.md
+++ b/README.md
@@ -1,14 +1,30 @@
-The code for the paper **"PathRAG: Pruning Graph-based Retrieval Augmented Generation with Relational Paths"**.
+# PathRAG: Pruning Graph-based Retrieval Augmented Generation with Relational Paths
+
 ## Install
+Python version: 3.10.18
 ```bash
 cd PathRAG
 pip install -e .
 ```
 ## Quick Start
-* You can quickly experience this project in the `v1_test.py` file.
-* Set OpenAI API key in environment if using OpenAI models: `api_key="sk-...".` in the `v1_test.py` and `llm.py` file
-* Prepare your retrieval document "text.txt".
-* Use the following Python snippet in the "v1_text.py" file to initialize PathRAG and perform queries.
+- This quick start uses the OpenAI API.
+- You can quickly try this project out in the `v1_test.py` file.
+- Set your OpenAI API key (`api_key="sk-..."`) in the `v1_test.py` and `llm.py` files if using OpenAI models.
+- Prepare your retrieval document "text.txt".
+- Use the following Python snippet in the `v1_test.py` file to initialize PathRAG and perform queries.
 
 ```python
 import os
@@ -37,9 +53,53 @@ with open(data_file) as f:
 
 print(rag.query(question, param=QueryParam(mode="hybrid")))
 ```
 
+## Quick Start with models from different sources
+- You can use models from Hugging Face, Ollama, ModelScope, vLLM, or a local checkpoint.
+- You can quickly try this out in the `rag_test.py` file.
+- Select your model source: hf / vllm / ollama / ms / local.
+- Prepare your `llm_model`, `embedding_model`, and retrieval document ("your data file").
+- Use the following Python snippet in the `rag_test.py` file to run models from different sources.
+- Detailed examples can be found in `不同模型样例.txt`.
+```python
+import os
+import asyncio
+import torch
+from PathRAG.RAGRunner import RAGRunner
+
+if __name__ == "__main__":
+    backend = "your_model_source"  # choose from: hf / vllm / ollama / ms / local
+    working_dir = "your_working_dir"  # working directory
+    llm_model_name = "Qwen/Qwen3-0.6B"  # chat model name or path
+    embedding_model_name = "iic/nlp_corom_sentence-embedding_english-base"  # embedding model name or path
+    # extra parameters for the ollama backend
+    llm_model_kwargs = {
+        "host": "http://localhost:11434",
+        "options": {"num_ctx": 8192},
+        "timeout": 300,
+    } if backend == "ollama" else {}
+
+    runner = RAGRunner(
+        backend=backend,
+        working_dir=working_dir,
+        llm_model_name=llm_model_name,
+        embedding_model_name=embedding_model_name,
+        llm_model_max_token_size=8192,
+        llm_model_kwargs=llm_model_kwargs,
+        embedding_dim=768,
+        embedding_max_token_size=5000,
+    )
+    data_file = "your_data_file"
+    question = "your_question"
+    with open(data_file, "r", encoding="utf-8") as f:
+        runner.insert_text(f.read())
+    answer = runner.query(question, mode="hybrid")
+    print("Question:", question)
+    print("Answer:", answer)
+```
+
 ## Parameter modification
 You can adjust the relevant parameters in the `base.py` and `operate.py` files.
-
+- You can change the hyperparameter `top_k` in `base.py`, where `top_k` is the number of retrieved nodes.
+- You can change the hyperparameters `alpha` and `threshold` in `operate.py`, where `alpha` is the decay rate of information propagation along the edges and `threshold` is the pruning threshold.
 ## Batch Insert
 ```python
 import os
@@ -52,13 +112,75 @@ for file_name in txt_files:
         rag.insert(file.read())
 ```
 
-## Cite
+## Evaluation
+
+### Dataset
+The dataset used in PathRAG can be downloaded from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain).
+
+### Eval Metrics
+
+#### Prompt
+
+```python
+You will evaluate two answers to the same question based on five criteria: **Comprehensiveness**, **Diversity**, **Logicality**, **Coherence**, **Relevance**.
+
+- **Comprehensiveness**: How much detail does the answer provide to cover all aspects and details of the question?
+- **Diversity**: How varied and rich is the answer in providing different perspectives and insights on the question?
+- **Logicality**: How logically does the answer respond to all parts of the question?
+- **Coherence**: How well does the answer maintain internal logical connections between its parts, ensuring a smooth and consistent structure?
+- **Relevance**: How relevant is the answer to the question, staying focused and addressing the intended topic or issue?
+
+For each criterion, choose the better answer (either Answer 1 or Answer 2) and explain why.
+
+Here is the question:
+{query}
+
+Here are the two answers:
+**Answer 1:**
+{answer1}
+
+**Answer 2:**
+{answer2}
+
+Evaluate both answers using the five criteria listed above and provide detailed explanations for each criterion.
+
+Output your evaluation in the following JSON format:
+{{
+    "Comprehensiveness": {{
+        "Winner": "[Answer 1 or Answer 2]",
+        "Explanation": "[Provide explanation here]"
+    }},
+    "Diversity": {{
+        "Winner": "[Answer 1 or Answer 2]",
+        "Explanation": "[Provide explanation here]"
+    }},
+    "Logicality": {{
+        "Winner": "[Answer 1 or Answer 2]",
+        "Explanation": "[Provide explanation here]"
+    }},
+    "Coherence": {{
+        "Winner": "[Answer 1 or Answer 2]",
+        "Explanation": "[Provide explanation here]"
+    }},
+    "Relevance": {{
+        "Winner": "[Answer 1 or Answer 2]",
+        "Explanation": "[Provide explanation here]"
+    }}
+}}
+```
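+
+The template escapes its JSON braces as `{{ }}`, so it can be filled directly with Python's `str.format`. The snippet below is a minimal, hypothetical sketch of wiring the prompt into an LLM judge and reading back the verdict; the `judge_answers` helper, the `gpt-4o-mini` judge model, and the OpenAI client are illustrative assumptions, not part of this repository.
+
+```python
+import json
+from openai import OpenAI
+
+client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment
+
+def judge_answers(prompt_template: str, query: str, answer1: str, answer2: str) -> dict:
+    """Fill the evaluation prompt and return the judge's parsed JSON verdict."""
+    prompt = prompt_template.format(query=query, answer1=answer1, answer2=answer2)
+    resp = client.chat.completions.create(
+        model="gpt-4o-mini",  # placeholder judge model
+        messages=[{"role": "user", "content": prompt}],
+    )
+    return json.loads(resp.choices[0].message.content)
+
+# Example: verdict = judge_answers(PROMPT, query, answer1, answer2)
+# verdict["Comprehensiveness"]["Winner"] is "Answer 1" or "Answer 2"
+```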
+
+## Contribution
+Boyu Chen, Zirui Guo, Zidan Yang, Junfei Bao
+
+## Citation
 Please cite our paper if you use this code in your own work:
 ```python
 @article{chen2025pathrag,
-  title={PathRAG: Pruning Graph-based Retrieval Augmented Generation with Relational Paths},
-  author={Chen, Boyu and Guo, Zirui and Yang, Zidan and Chen, Yuluo and Chen, Junze and Liu, Zhenghao and Shi, Chuan and Yang, Cheng},
-  journal={arXiv preprint arXiv:2502.14902},
-  year={2025}
+    title={PathRAG: Pruning Graph-based Retrieval Augmented Generation with Relational Paths},
+    author={Boyu Chen and Zirui Guo and Zidan Yang and Yuluo Chen and Junze Chen and Zhenghao Liu and Chuan Shi and Cheng Yang},
+    journal={arXiv preprint arXiv:2502.14902},
+    year={2025}
 }
 ```