diff --git a/PathRAG/llm.py b/PathRAG/llm.py
index 80b5580..959e92f 100644
--- a/PathRAG/llm.py
+++ b/PathRAG/llm.py
@@ -12,6 +12,8 @@
import ollama
import torch
import time
+import modelscope as ms
+from vllm import LLM
from openai import (
AsyncOpenAI,
APIConnectionError,
@@ -36,6 +38,7 @@
)
import sys
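+# Default inference device shared by the ModelScope / local / vLLM helpers added below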
+device = "cuda" if torch.cuda.is_available() else "cpu"
if sys.version_info < (3, 9):
from typing import AsyncIterator
@@ -223,16 +226,32 @@ async def bedrock_complete_if_cache(
@lru_cache(maxsize=1)
def initialize_hf_model(model_name):
- hf_tokenizer = AutoTokenizer.from_pretrained(
- model_name, device_map="auto", trust_remote_code=True
+    # Use the GPU if available, otherwise fall back to the CPU
+ use_gpu = torch.cuda.is_available()
+ dtype = torch.bfloat16 if use_gpu else torch.float32
+ device_map = "auto" if use_gpu else "cpu"
+
+    # Load the tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(
+ model_name,
+ trust_remote_code=True,
+        use_fast=False  # some models, such as Qwen, do not support the fast tokenizer
)
- hf_model = AutoModelForCausalLM.from_pretrained(
- model_name, device_map="auto", trust_remote_code=True
+
+    # Load the model
+ model = AutoModelForCausalLM.from_pretrained(
+ model_name,
+ trust_remote_code=True,
+ torch_dtype=dtype,
+ device_map=device_map,
+ low_cpu_mem_usage=True
)
- if hf_tokenizer.pad_token is None:
- hf_tokenizer.pad_token = hf_tokenizer.eos_token
- return hf_model, hf_tokenizer
+    # Set a pad_token if the tokenizer does not define one
+ if tokenizer.pad_token is None:
+ tokenizer.pad_token = tokenizer.eos_token
+
+ return model, tokenizer
@retry(
@@ -290,15 +309,14 @@ async def hf_model_if_cache(
input_ids = hf_tokenizer(
input_prompt, return_tensors="pt", padding=True, truncation=True
- ).to("cuda")
+    )  # device placement handled below via hf_model.device instead of a hard-coded .to("cuda")
inputs = {k: v.to(hf_model.device) for k, v in input_ids.items()}
output = hf_model.generate(
- **input_ids, max_new_tokens=512, num_return_sequences=1, early_stopping=True
+        **inputs, max_new_tokens=512, num_return_sequences=1
)
response_text = hf_tokenizer.decode(
output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
)
-
return response_text
@@ -993,6 +1011,37 @@ async def hf_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
else:
return embeddings.detach().cpu().numpy()
+async def ms_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
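+    """Embed texts with a ModelScope encoder by mean-pooling its last hidden state."""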
+ device = next(embed_model.parameters()).device
+ input_ids = tokenizer(
+ texts, return_tensors="pt", padding=True, truncation=True
+ ).input_ids.to(device)
+ with torch.no_grad():
+ outputs = embed_model(input_ids)
+ embeddings = outputs.last_hidden_state.mean(dim=1)
+ if embeddings.dtype == torch.bfloat16:
+ return embeddings.detach().to(torch.float32).cpu().numpy()
+ else:
+ return embeddings.detach().cpu().numpy()
+
+async def local_embedding(texts: list[str], tokenizer=None, embed_model=None) -> np.ndarray:
+ if tokenizer is None or embed_model is None:
+ raise ValueError("Tokenizer and model must be provided")
+ device = next(embed_model.parameters()).device
+ encoded = tokenizer(
+ texts,
+ padding=True,
+ truncation=True,
+ return_tensors="pt"
+ ).input_ids.to(device)
+ with torch.no_grad():
+ outputs = embed_model(encoded)
+ embeddings = outputs.last_hidden_state.mean(dim=1)
+ if embeddings.dtype == torch.bfloat16:
+ return embeddings.detach().to(torch.float32).cpu().numpy()
+ else:
+ return embeddings.detach().cpu().numpy()
+
async def ollama_embedding(texts: list[str], embed_model, **kwargs) -> np.ndarray:
"""
@@ -1013,6 +1062,313 @@ async def ollama_embed(texts: list[str], embed_model, **kwargs) -> np.ndarray:
return data["embeddings"]
+
+@lru_cache(maxsize=1)
+def initialize_ms_model(model_name: str):
+ """加载 ModelScope 聊天模型并缓存(底层实现)"""
+ tokenizer = ms.AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+ model = ms.AutoModelForCausalLM.from_pretrained(
+ model_name,
+ torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
+ trust_remote_code=True,
+ low_cpu_mem_usage=True
+ ).to(device).eval()
+ return tokenizer, model
+
+@retry(
+ stop=stop_after_attempt(3),
+ wait=wait_exponential(multiplier=1, min=4, max=10),
+ retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+)
+async def ms_model_if_cache(
+ model: str,
+ prompt: str,
+ system_prompt: str = None,
+ history_messages: list = [],
+ **kwargs,
+) -> str:
+ tokenizer, model = initialize_ms_model(model)
+ messages = []
+ if system_prompt:
+ messages.append({"role": "system", "content": system_prompt})
+ messages.extend(history_messages)
+ messages.append({"role": "user", "content": prompt})
+ kwargs.pop("hashing_kv", None)
+ input_prompt = ""
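+    # Prefer the tokenizer's chat template; if it rejects these messages, merge the
+    # system prompt into the first user turn, and as a last resort build a plain
+    # <role>...</role> transcript by hand.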
+ try:
+ input_prompt = tokenizer.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+ )
+ except Exception:
+ try:
+ ori_message = copy.deepcopy(messages)
+ if messages[0]["role"] == "system":
+ messages[1]["content"] = (
+                    "<system>"
+                    + messages[0]["content"]
+                    + "</system>\n"
+ + messages[1]["content"]
+ )
+ messages = messages[1:]
+ input_prompt = tokenizer.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+ )
+ except Exception:
+ len_message = len(ori_message)
+ for msgid in range(len_message):
+ input_prompt = (
+ input_prompt
+ + "<"
+ + ori_message[msgid]["role"]
+ + ">"
+ + ori_message[msgid]["content"]
+                    + "</"
+ + ori_message[msgid]["role"]
+ + ">\n"
+ )
+
+ input_ids = tokenizer(
+ input_prompt, return_tensors="pt", padding=True, truncation=True
+ ).to(device)
+ inputs = {k: v.to(model.device) for k, v in input_ids.items()}
+ output = model.generate(
+ **input_ids, max_new_tokens=512, num_return_sequences=1
+ )
+ response_text = tokenizer.decode(
+ output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
+ )
+ return response_text
+
+
+async def ms_model_complete(
+ prompt: str,
+ system_prompt: str = None,
+ history_messages: list = [],
+ keyword_extraction: bool = False,
+ **kwargs,
+) -> str:
+ keyword_extraction = kwargs.pop("keyword_extraction", None)
+ model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
+ result = await ms_model_if_cache(
+ model_name,
+ prompt,
+ system_prompt=system_prompt,
+ history_messages=history_messages,
+ **kwargs,
+ )
+ if keyword_extraction:
+ return locate_json_string_body_from_string(result)
+
+ return result
+
+
+@lru_cache(maxsize=1)
+def initialize_local_model(model_path: str):
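+    """Load and cache a chat model and tokenizer from a local path (Hugging Face format)."""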
+ tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ trust_remote_code=True,
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+ ).to(device).eval()
+ if tokenizer.pad_token is None:
+ tokenizer.pad_token = tokenizer.eos_token
+
+ return model, tokenizer
+
+
+@retry(
+ stop=stop_after_attempt(3),
+ wait=wait_exponential(multiplier=1, min=4, max=10),
+ retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+)
+async def local_model_if_cache(
+ model,
+ prompt,
+ system_prompt=None,
+ history_messages=[],
+ **kwargs,
+) -> str:
+ model_name = model
+ local_model, local_tokenizer = initialize_local_model(model_name)
+ messages = []
+ if system_prompt:
+ messages.append({"role": "system", "content": system_prompt})
+ messages.extend(history_messages)
+ messages.append({"role": "user", "content": prompt})
+ kwargs.pop("hashing_kv", None)
+ input_prompt = ""
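+    # Same fallback chain as ms_model_if_cache: chat template first, then the system
+    # prompt merged into the first user turn, then a hand-built <role>...</role> transcript.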
+ try:
+ input_prompt = local_tokenizer.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+ )
+ except Exception:
+ try:
+ ori_message = copy.deepcopy(messages)
+ if messages[0]["role"] == "system":
+ messages[1]["content"] = (
+                    "<system>"
+                    + messages[0]["content"]
+                    + "</system>\n"
+ + messages[1]["content"]
+ )
+ messages = messages[1:]
+ input_prompt = local_tokenizer.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+ )
+ except Exception:
+ len_message = len(ori_message)
+ for msgid in range(len_message):
+ input_prompt = (
+ input_prompt
+ + "<"
+ + ori_message[msgid]["role"]
+ + ">"
+ + ori_message[msgid]["content"]
+                    + "</"
+ + ori_message[msgid]["role"]
+ + ">\n"
+ )
+ input_ids = local_tokenizer(
+ input_prompt, return_tensors="pt", padding=True, truncation=True
+ ).to(device)
+ inputs = {k: v.to(local_model.device) for k, v in input_ids.items()}
+ output = local_model.generate(
+ **input_ids, max_new_tokens=512, num_return_sequences=1
+ )
+ response_text = local_tokenizer.decode(
+ output[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
+ )
+ return response_text
+
+async def local_model_complete(
+ prompt: str,
+ system_prompt: str = None,
+ history_messages: list = [],
+ keyword_extraction: bool = False,
+ **kwargs,
+) -> str:
+ keyword_extraction = kwargs.pop("keyword_extraction", None)
+ model_path = kwargs["hashing_kv"].global_config["llm_model_name"]
+ result = await local_model_if_cache(
+ model_path,
+ prompt,
+ system_prompt=system_prompt,
+ history_messages=history_messages,
+ **kwargs,
+ )
+ if keyword_extraction:
+        result = locate_json_string_body_from_string(result)
+ return result
+
+
+@lru_cache(maxsize=1)
+def initialize_vllm_model(model_name: str):
+ """加载 vLLM 聊天模型并缓存"""
+ model = LLM(model_name, device=device, max_model_len=8192)#修改
+ return model
+
+
+@retry(
+ stop=stop_after_attempt(3),
+ wait=wait_exponential(multiplier=1, min=4, max=10),
+ retry=retry_if_exception_type((RateLimitError, APIConnectionError, Timeout)),
+)
+async def vllm_model_if_cache(
+ model: str,
+ prompt: str,
+ system_prompt: str = None,
+ history_messages: list = [],
+ **kwargs,
+) -> str:
+ vllm_model = initialize_vllm_model(model)
+
+    # Build the multi-turn conversation
+ messages = []
+ if system_prompt:
+ messages.append({"role": "system", "content": system_prompt})
+ messages.extend(history_messages)
+ messages.append({"role": "user", "content": prompt})
+ kwargs.pop("hashing_kv", None)
+ input_prompt = ""
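+    # Try vLLM's chat interface first; if it rejects these messages, merge the system
+    # prompt into the first user turn and retry, then fall back to a hand-built prompt string.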
+ try:
+ result = vllm_model.chat(messages)
+ except Exception:
+ try:
+ ori_message = copy.deepcopy(messages)
+ if messages[0]["role"] == "system":
+ messages[1]["content"] = (
+                "<system>"
+                + messages[0]["content"]
+                + "</system>\n"
+ + messages[1]["content"]
+ )
+ messages = messages[1:]
+ result = vllm_model.chat(messages)
+ except Exception:
+ len_message = len(ori_message)
+ for msgid in range(len_message):
+                input_prompt = (
+                    input_prompt
+                    + "<"
+                    + ori_message[msgid]["role"]
+                    + ">"
+                    + ori_message[msgid]["content"]
+                    + "</"
+                    + ori_message[msgid]["role"]
+                    + ">\n"
+                )
+    if input_prompt:
+        # Both chat attempts failed; generate from the hand-built prompt instead.
+        result = vllm_model.generate([input_prompt])
+    return result[0].outputs[0].text
+
+
+async def vllm_model_complete(
+ prompt: str,
+ system_prompt: str = None,
+ history_messages: list = [],
+ keyword_extraction: bool = False,
+ **kwargs,
+) -> str:
+ keyword_extraction = kwargs.pop("keyword_extraction", None)
+ model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
+ result = await vllm_model_if_cache(
+ model_name,
+ prompt,
+ system_prompt=system_prompt,
+ history_messages=history_messages,
+ **kwargs,
+ )
+ if keyword_extraction:
+ return locate_json_string_body_from_string(result)
+ return result
+
+
+async def vllm_embedding(texts: list[str], tokenizer, embed_model) -> np.ndarray:
+ device = next(embed_model.parameters()).device
+ input_ids = tokenizer(
+ texts, return_tensors="pt", padding=True, truncation=True
+ ).input_ids.to(device)
+ with torch.no_grad():
+ outputs = embed_model(input_ids)
+ embeddings = outputs.last_hidden_state.mean(dim=1)
+ if embeddings.dtype == torch.bfloat16:
+ return embeddings.detach().to(torch.float32).cpu().numpy()
+ else:
+ return embeddings.detach().cpu().numpy()
+
+
class Model(BaseModel):
"""
This is a Pydantic model class named 'Model' that is used to define a custom language model.
@@ -1099,6 +1455,5 @@ async def llm_model_func(
async def main():
result = await gpt_4o_mini_complete("How are you?")
- print(result)
asyncio.run(main())
diff --git a/README.md b/README.md
index e7ad841..f7c246f 100644
--- a/README.md
+++ b/README.md
@@ -1,14 +1,30 @@
-The code for the paper **"PathRAG: Pruning Graph-based Retrieval Augmented Generation with Relational Paths"**.
+# PathRAG: Pruning Graph-based Retrieval Augmented Generation with Relational Paths
+
## Install
+Python version: 3.10.18
```bash
cd PathRAG
pip install -e .
```
## Quick Start
-* You can quickly experience this project in the `v1_test.py` file.
-* Set OpenAI API key in environment if using OpenAI models: `api_key="sk-...".` in the `v1_test.py` and `llm.py` file
-* Prepare your retrieval document "text.txt".
-* Use the following Python snippet in the "v1_text.py" file to initialize PathRAG and perform queries.
+- This quick start uses an OpenAI API key.
+- You can quickly experience this project in the `v1_test.py` file.
+- If you use OpenAI models, set your API key (`api_key="sk-..."`) in the `v1_test.py` and `llm.py` files.
+- Prepare your retrieval document "text.txt".
+- Use the following Python snippet in the `v1_test.py` file to initialize PathRAG and perform queries.
```python
import os
@@ -37,9 +53,53 @@ with open(data_file) as f:
print(rag.query(question, param=QueryParam(mode="hybrid")))
```
+## Quick Start with models from different sources
+- You can use models from Hugging Face, Ollama, ModelScope, vLLM, or a local path.
+- You can quickly experience this project in the `rag_test.py` file.
+- Select your model source: hf / vllm / ollama / ms / local.
+- Prepare your LLM model, embedding model, and retrieval document ("your data file").
+- Use the following Python snippet in the `rag_test.py` file to run models from different sources.
+- Detailed examples can be found in `不同模型样例.txt`.
+```Python
+import os
+import asyncio
+import torch
+from PathRAG.RAGRunner import RAGRunner
+if __name__ == "__main__":
+    backend = "your_model_source"  # one of: hf / vllm / ollama / ms / local
+    working_dir = "your_working_dir"  # working directory
+    llm_model_name = "Qwen/Qwen3-0.6B"  # chat model name or path
+    embedding_model_name = "iic/nlp_corom_sentence-embedding_english-base"  # embedding model name or path
+    # extra options for the ollama backend
+ llm_model_kwargs = {
+ "host": "http://localhost:11434",
+ "options": {"num_ctx": 8192},
+ "timeout": 300,
+ } if backend == "ollama" else {}
+
+ runner = RAGRunner(
+ backend=backend,
+ working_dir=working_dir,
+ llm_model_name=llm_model_name,
+ embedding_model_name=embedding_model_name,
+ llm_model_max_token_size=8192,
+ llm_model_kwargs=llm_model_kwargs,
+ embedding_dim=768,
+ embedding_max_token_size=5000,
+ )
+ data_file = "your_data_file"
+ question = "your_question"
+ with open(data_file, "r", encoding="utf-8") as f:
+ runner.insert_text(f.read())
+ answer = runner.query(question, mode="hybrid")
+    print("Q:", question)
+    print("A:", answer)
+```
+
## Parameter modification
You can adjust the relevant parameters in the `base.py` and `operate.py` files.
-
+- You can change the hyperparameter `top_k` in `base.py`, where `top_k` is the number of nodes retrieved (see the sketch below).
+- You can change the hyperparameters `alpha` and `threshold` in `operate.py`, where `alpha` is the decay rate of information propagation along the edges and `threshold` is the pruning threshold.
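+
+A minimal sketch of overriding `top_k` at query time, assuming `QueryParam` (defined in `base.py`) exposes it as a field and reusing the `rag` and `question` objects from the Quick Start above; `alpha` and `threshold` are edited directly in `operate.py`:
+```python
+from PathRAG import QueryParam
+
+# Hypothetical value: retrieve 40 nodes instead of the default before path pruning.
+param = QueryParam(mode="hybrid", top_k=40)
+print(rag.query(question, param=param))
+```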
## Batch Insert
```python
import os
@@ -52,13 +112,75 @@ for file_name in txt_files:
rag.insert(file.read())
```
-## Cite
+## Evaluation
+
+### Dataset
+The dataset used in PathRAG can be downloaded from [TommyChien/UltraDomain](https://huggingface.co/datasets/TommyChien/UltraDomain).
+### Eval Metrics
+
+#### Prompt
+
+```
+You will evaluate two answers to the same question based on five criteria: **Comprehensiveness**, **Diversity**, **Logicality**, **Coherence**, and **Relevance**.
+
+- **Comprehensiveness**: How much detail does the answer provide to cover all aspects and details of the question?
+- **Diversity**: How varied and rich is the answer in providing different perspectives and insights on the question?
+- **Logicality**: How logically does the answer respond to all parts of the question?
+- **Coherence**: How well does the answer maintain internal logical connections between its parts, ensuring a smooth and consistent structure?
+- **Relevance**: How relevant is the answer to the question, staying focused and addressing the intended topic or issue?
+
+For each criterion, choose the better answer (either Answer 1 or Answer 2) and explain why.
+
+Here is the question:
+{query}
+
+Here are the two answers:
+**Answer 1:**
+{answer1}
+
+**Answer 2:**
+{answer2}
+
+Evaluate both answers using the five criteria listed above and provide detailed explanations for each criterion.
+
+Output your evaluation in the following JSON format:
+{{
+ "Comprehensiveness": {{
+ "Winner": "[Answer 1 or Answer 2]",
+ "Explanation": "[Provide explanation here]"
+ }},
+ "Diversity": {{
+ "Winner": "[Answer 1 or Answer 2]",
+ "Explanation": "[Provide explanation here]"
+ }},
+    "Logicality": {{
+ "Winner": "[Answer 1 or Answer 2]",
+ "Explanation": "[Provide explanation here]"
+ }},
+ "Coherence": {{
+ "Winner": "[Answer 1 or Answer 2]",
+ "Explanation": "[Provide explanation here]"
+ }},
+ "Relevance": {{
+ "Winner": "[Answer 1 or Answer 2]",
+ "Explanation": "[Provide explanation here]"
+ }}
+}}
+```
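+
+A minimal sketch of running this head-to-head judgment with the `gpt_4o_mini_complete` helper from `PathRAG/llm.py`; the prompt string, questions, and answers below are placeholders, and an OpenAI API key is assumed to be configured as in the Quick Start:
+```python
+import asyncio
+from PathRAG.llm import gpt_4o_mini_complete
+
+# Paste the evaluation prompt above into this template string.
+EVAL_PROMPT = """<the prompt template above>"""
+
+async def judge(query: str, answer1: str, answer2: str) -> str:
+    # The doubled {{ }} braces in the template survive .format(), so only
+    # query / answer1 / answer2 are substituted here.
+    prompt = EVAL_PROMPT.format(query=query, answer1=answer1, answer2=answer2)
+    return await gpt_4o_mini_complete(prompt)
+
+verdict = asyncio.run(judge("your question", "PathRAG answer", "baseline answer"))
+print(verdict)  # JSON with a winner and explanation per criterion
+```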
+
+## Contributors
+Boyu Chen, Zirui Guo, Zidan Yang, Junfei Bao
+## Citation
Please cite our paper if you use this code in your own work:
```python
@article{chen2025pathrag,
- title={PathRAG: Pruning Graph-based Retrieval Augmented Generation with Relational Paths},
- author={Chen, Boyu and Guo, Zirui and Yang, Zidan and Chen, Yuluo and Chen, Junze and Liu, Zhenghao and Shi, Chuan and Yang, Cheng},
- journal={arXiv preprint arXiv:2502.14902},
- year={2025}
+ title={PathRAG: Pruning Graph-based Retrieval Augmented Generation with Relational Paths},
+ author={Boyu Chen and Zirui Guo and Zidan Yang and Yuluo Chen and Junze Chen and Zhenghao Liu and Chuan Shi and Cheng Yang},
+ journal={arXiv preprint arXiv:2502.14902},
+ year={2025}
}
```