
Commit 759b963

afterimagex and imbajin authored
feat(llm): support switch graph in api & add some query configs (#184)
TODO: we need to wrap the query configs

Co-authored-by: imbajin <[email protected]>
1 parent 3e0bf46 commit 759b963

File tree

11 files changed: +131 −50 lines


.github/workflows/hugegraph-python-client.yml

Lines changed: 1 addition & 1 deletion
```diff
@@ -20,7 +20,7 @@ jobs:
       - name: Prepare HugeGraph Server Environment
         run: |
           docker run -d --name=graph -p 8080:8080 -e PASSWORD=admin hugegraph/hugegraph:1.3.0
-          sleep 1
+          sleep 5
 
       - uses: actions/checkout@v4
```

README.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -37,7 +37,7 @@ And here are links of other repositories:
 
 - Welcome to contribute to HugeGraph, please see [Guidelines](https://hugegraph.apache.org/docs/contribution-guidelines/) for more information.
 - Note: It's recommended to use [GitHub Desktop](https://desktop.github.com/) to greatly simplify the PR and commit process.
-- Code format: Please run [`./style/code_format_and_analysis.sh`](style/code_format_and_analysis.sh) to format your code before submitting a PR.
+- Code format: Please run [`./style/code_format_and_analysis.sh`](style/code_format_and_analysis.sh) to format your code before submitting a PR. (Use `pylint` to check code style)
 - Thank you to all the people who already contributed to HugeGraph!
 
 [![contributors graph](https://contrib.rocks/image?repo=apache/incubator-hugegraph-ai)](https://github.com/apache/incubator-hugegraph-ai/graphs/contributors)
```

hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py

Lines changed: 32 additions & 11 deletions
```diff
@@ -23,8 +23,17 @@
 from hugegraph_llm.config import prompt
 
 
+class GraphConfigRequest(BaseModel):
+    ip: str = Query('127.0.0.1', description="hugegraph client ip.")
+    port: str = Query('8080', description="hugegraph client port.")
+    name: str = Query('hugegraph', description="hugegraph client name.")
+    user: str = Query('', description="hugegraph client user.")
+    pwd: str = Query('', description="hugegraph client pwd.")
+    gs: str = None
+
+
 class RAGRequest(BaseModel):
-    query: str = Query("", description="Query you want to ask")
+    query: str = Query(..., description="Query you want to ask")
     raw_answer: bool = Query(False, description="Use LLM to generate answer directly")
     vector_only: bool = Query(False, description="Use LLM to generate answer with vector")
     graph_only: bool = Query(True, description="Use LLM to generate answer with graph RAG only")
@@ -33,6 +42,16 @@ class RAGRequest(BaseModel):
     rerank_method: Literal["bleu", "reranker"] = Query("bleu", description="Method to rerank the results.")
     near_neighbor_first: bool = Query(False, description="Prioritize near neighbors in the search results.")
     custom_priority_info: str = Query("", description="Custom information to prioritize certain results.")
+    # Graph Configs
+    max_graph_items: int = Query(30, description="Maximum number of items for GQL queries in graph.")
+    topk_return_results: int = Query(20, description="Number of sorted results to return finally.")
+    vector_dis_threshold: float = Query(0.9, description="Threshold for vector similarity \
+        (results greater than this will be ignored).")
+    topk_per_keyword: int = Query(1, description="TopK results returned for each keyword \
+        extracted from the query, by default only the most similar one is returned.")
+    client_config: Optional[GraphConfigRequest] = Query(None, description="hugegraph server config.")
+
+    # Keep prompt params in the end
     answer_prompt: Optional[str] = Query(prompt.answer_prompt, description="Prompt to guide the answer generation.")
     keywords_extract_prompt: Optional[str] = Query(
         prompt.keywords_extract_prompt,
@@ -47,7 +66,18 @@ class RAGRequest(BaseModel):
 
 # TODO: import the default value of prompt.* dynamically
 class GraphRAGRequest(BaseModel):
-    query: str = Query("", description="Query you want to ask")
+    query: str = Query(..., description="Query you want to ask")
+    # Graph Configs
+    max_graph_items: int = Query(30, description="Maximum number of items for GQL queries in graph.")
+    topk_return_results: int = Query(20, description="Number of sorted results to return finally.")
+    vector_dis_threshold: float = Query(0.9, description="Threshold for vector similarity \
+        (results greater than this will be ignored).")
+    topk_per_keyword: int = Query(1, description="TopK results returned for each keyword extracted \
+        from the query, by default only the most similar one is returned.")
+
+    client_config: Optional[GraphConfigRequest] = Query(None, description="hugegraph server config.")
+    get_vid_only: bool = Query(False, description="return only keywords & vid (early stop).")
+
     gremlin_tmpl_num: int = Query(
         1, description="Number of Gremlin templates to use. If num <=0 means template is not provided"
     )
@@ -60,15 +90,6 @@ class GraphRAGRequest(BaseModel):
     )
 
 
-class GraphConfigRequest(BaseModel):
-    ip: str = "127.0.0.1"
-    port: str = "8080"
-    name: str = "hugegraph"
-    user: str = "xxx"
-    pwd: str = "xxx"
-    gs: str = None
-
-
 class LLMConfigRequest(BaseModel):
     llm_type: str
     # The common parameters shared by OpenAI, Qianfan Wenxin,
```
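With these models, both the tuning knobs and the target graph connection travel in the request body. A minimal client-side sketch of a `/rag` call using the new fields (the base URL and credential values are illustrative assumptions, not part of this commit; the `/rag` route itself comes from `rag_api.py` below):

```python
import requests  # hypothetical client-side sketch; assumes a locally served hugegraph-llm API

payload = {
    "query": "Tell me about the graph.",  # now required: Query(...) instead of Query("")
    # New query configs introduced by this commit (values are the declared defaults)
    "max_graph_items": 30,        # cap on items for GQL queries in graph
    "topk_return_results": 20,    # number of sorted results to return finally
    "vector_dis_threshold": 0.9,  # vid matches above this distance are ignored
    "topk_per_keyword": 1,        # keep only the closest vid per extracted keyword
    # New: per-request HugeGraph connection, used to switch graphs
    "client_config": {
        "ip": "127.0.0.1",
        "port": "8080",
        "name": "hugegraph",
        "user": "",
        "pwd": "",
    },
}
resp = requests.post("http://127.0.0.1:8001/rag", json=payload)  # base URL is an assumption
print(resp.json())
```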

hugegraph-llm/src/hugegraph_llm/api/rag_api.py

Lines changed: 25 additions & 1 deletion
```diff
@@ -27,6 +27,7 @@
     RerankerConfigRequest,
     GraphRAGRequest,
 )
+from hugegraph_llm.config import huge_settings
 from hugegraph_llm.api.models.rag_response import RAGResponse
 from hugegraph_llm.config import llm_settings, prompt
 from hugegraph_llm.utils.log import log
@@ -43,6 +44,8 @@ def rag_http_api(
 ):
     @router.post("/rag", status_code=status.HTTP_200_OK)
     def rag_answer_api(req: RAGRequest):
+        set_graph_config(req)
+
         result = rag_answer_func(
             text=req.query,
             raw_answer=req.raw_answer,
@@ -52,10 +55,15 @@ def rag_answer_api(req: RAGRequest):
             graph_ratio=req.graph_ratio,
             rerank_method=req.rerank_method,
             near_neighbor_first=req.near_neighbor_first,
+            gremlin_tmpl_num=req.gremlin_tmpl_num,
+            max_graph_items=req.max_graph_items,
+            topk_return_results=req.topk_return_results,
+            vector_dis_threshold=req.vector_dis_threshold,
+            topk_per_keyword=req.topk_per_keyword,
+            # Keep prompt params in the end
             custom_related_information=req.custom_priority_info,
             answer_prompt=req.answer_prompt or prompt.answer_prompt,
             keywords_extract_prompt=req.keywords_extract_prompt or prompt.keywords_extract_prompt,
-            gremlin_tmpl_num=req.gremlin_tmpl_num,
             gremlin_prompt=req.gremlin_prompt or prompt.gremlin_generate_prompt,
         )
         # TODO: we need more info in the response for users to understand the query logic
@@ -68,16 +76,32 @@ def rag_answer_api(req: RAGRequest):
             },
         }
 
+    def set_graph_config(req):
+        if req.client_config:
+            huge_settings.graph_ip = req.client_config.ip
+            huge_settings.graph_port = req.client_config.port
+            huge_settings.graph_name = req.client_config.name
+            huge_settings.graph_user = req.client_config.user
+            huge_settings.graph_pwd = req.client_config.pwd
+            huge_settings.graph_space = req.client_config.gs
+
     @router.post("/rag/graph", status_code=status.HTTP_200_OK)
     def graph_rag_recall_api(req: GraphRAGRequest):
         try:
+            set_graph_config(req)
+
             result = graph_rag_recall_func(
                 query=req.query,
+                max_graph_items=req.max_graph_items,
+                topk_return_results=req.topk_return_results,
+                vector_dis_threshold=req.vector_dis_threshold,
+                topk_per_keyword=req.topk_per_keyword,
                 gremlin_tmpl_num=req.gremlin_tmpl_num,
                 rerank_method=req.rerank_method,
                 near_neighbor_first=req.near_neighbor_first,
                 custom_related_information=req.custom_priority_info,
                 gremlin_prompt=req.gremlin_prompt or prompt.gremlin_generate_prompt,
+                get_vid_only=req.get_vid_only
             )
 
             if isinstance(result, dict):
```
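`set_graph_config` copies a request's `client_config` into the process-global `huge_settings` before the pipeline runs, which is what lets one running service be pointed at different HugeGraph instances per request. A hedged sketch of the new `/rag/graph` early-stop path (base URL assumed, as above):

```python
import requests

payload = {
    "query": "Which vertices match these keywords?",
    # Early stop: return only the extracted keywords & matched vids,
    # skipping schema import, graph query, and rerank
    "get_vid_only": True,
    # Switch this request to another HugeGraph instance (values illustrative)
    "client_config": {"ip": "127.0.0.1", "port": "8080", "name": "hugegraph", "user": "", "pwd": ""},
}
resp = requests.post("http://127.0.0.1:8001/rag/graph", json=payload)  # host/port are assumptions
print(resp.json())
```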

hugegraph-llm/src/hugegraph_llm/demo/rag_demo/rag_block.py

Lines changed: 15 additions & 5 deletions
```diff
@@ -43,6 +43,10 @@ def rag_answer(
     keywords_extract_prompt: str,
     gremlin_tmpl_num: Optional[int] = 2,
     gremlin_prompt: Optional[str] = None,
+    max_graph_items=30,
+    topk_return_results=20,
+    vector_dis_threshold=0.9,
+    topk_per_keyword=1,
 ) -> Tuple:
     """
     Generate an answer using the RAG (Retrieval-Augmented Generation) pipeline.
@@ -79,22 +83,28 @@ def rag_answer(
     if vector_search:
         rag.query_vector_index()
     if graph_search:
-        rag.extract_keywords(extract_template=keywords_extract_prompt).keywords_to_vid().import_schema(
+        rag.extract_keywords(extract_template=keywords_extract_prompt).keywords_to_vid(
+            vector_dis_threshold=vector_dis_threshold,
+            topk_per_keyword=topk_per_keyword,
+        ).import_schema(
             huge_settings.graph_name
         ).query_graphdb(
             num_gremlin_generate_example=gremlin_tmpl_num,
             gremlin_prompt=gremlin_prompt,
+            max_graph_items=max_graph_items
         )
     # TODO: add more user-defined search strategies
     rag.merge_dedup_rerank(
-        graph_ratio,
-        rerank_method,
-        near_neighbor_first,
+        graph_ratio=graph_ratio,
+        rerank_method=rerank_method,
+        near_neighbor_first=near_neighbor_first,
+        topk_return_results=topk_return_results
     )
     rag.synthesize_answer(raw_answer, vector_only_answer, graph_only_answer, graph_vector_answer, answer_prompt)
 
     try:
-        context = rag.run(verbose=True, query=text, vector_search=vector_search, graph_search=graph_search)
+        context = rag.run(verbose=True, query=text, vector_search=vector_search, graph_search=graph_search,
+                          max_graph_items=max_graph_items)
         if context.get("switch_to_bleu"):
             gr.Warning("Online reranker fails, automatically switches to local bleu rerank.")
         return (
```

hugegraph-llm/src/hugegraph_llm/demo/rag_demo/text2gremlin_block.py

Lines changed: 20 additions & 9 deletions
```diff
@@ -188,17 +188,28 @@ def graph_rag_recall(
     near_neighbor_first: bool,
     custom_related_information: str,
     gremlin_prompt: str,
+    max_graph_items: int,
+    topk_return_results: int,
+    vector_dis_threshold: float,
+    topk_per_keyword: int,
+    get_vid_only: bool
 ) -> dict:
     store_schema(prompt.text2gql_graph_schema, query, gremlin_prompt)
     rag = RAGPipeline()
-
-    rag.extract_keywords().keywords_to_vid().import_schema(huge_settings.graph_name).query_graphdb(
-        num_gremlin_generate_example=gremlin_tmpl_num,
-        gremlin_prompt=gremlin_prompt,
-    ).merge_dedup_rerank(
-        rerank_method=rerank_method,
-        near_neighbor_first=near_neighbor_first,
-        custom_related_information=custom_related_information,
-    )
+    rag.extract_keywords().keywords_to_vid(
+        vector_dis_threshold=vector_dis_threshold,
+        topk_per_keyword=topk_per_keyword,
+    )
+    if not get_vid_only:
+        rag.import_schema(huge_settings.graph_name).query_graphdb(
+            num_gremlin_generate_example=gremlin_tmpl_num,
+            gremlin_prompt=gremlin_prompt,
+            max_graph_items=max_graph_items,
+        ).merge_dedup_rerank(
+            rerank_method=rerank_method,
+            near_neighbor_first=near_neighbor_first,
+            custom_related_information=custom_related_information,
+            topk_return_results=topk_return_results,
+        )
     context = rag.run(verbose=True, query=query, graph_search=True)
     return context
```

hugegraph-llm/src/hugegraph_llm/operators/common_op/merge_dedup_rerank.py

Lines changed: 6 additions & 6 deletions
```diff
@@ -44,7 +44,7 @@ class MergeDedupRerank:
     def __init__(
         self,
         embedding: BaseEmbedding,
-        topk: int = huge_settings.topk_return_results,
+        topk_return_results: int = huge_settings.topk_return_results,
         graph_ratio: float = 0.5,
         method: Literal["bleu", "reranker"] = "bleu",
         near_neighbor_first: bool = False,
@@ -54,7 +54,7 @@ def __init__(
         assert method in ["bleu", "reranker"], f"Unimplemented rerank method '{method}'."
         self.embedding = embedding
         self.graph_ratio = graph_ratio
-        self.topk = topk
+        self.topk_return_results = topk_return_results
         self.method = method
         self.near_neighbor_first = near_neighbor_first
         self.custom_related_information = custom_related_information
@@ -70,11 +70,11 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
         vector_search = context.get("vector_search", False)
         graph_search = context.get("graph_search", False)
         if graph_search and vector_search:
-            graph_length = int(self.topk * self.graph_ratio)
-            vector_length = self.topk - graph_length
+            graph_length = int(self.topk_return_results * self.graph_ratio)
+            vector_length = self.topk_return_results - graph_length
         else:
-            graph_length = self.topk
-            vector_length = self.topk
+            graph_length = self.topk_return_results
+            vector_length = self.topk_return_results
 
         vector_result = context.get("vector_result", [])
         vector_length = min(len(vector_result), vector_length)
```
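The renamed `topk_return_results` budget is split between the two result sources only when both searches are active. A small worked example of the arithmetic above:

```python
# Worked example of the budget split in MergeDedupRerank.run (values illustrative)
topk_return_results = 20
graph_ratio = 0.5

# Both graph and vector search enabled:
graph_length = int(topk_return_results * graph_ratio)  # 10 results from the graph side
vector_length = topk_return_results - graph_length     # 10 results from the vector side

# Only one source enabled: each side may use the full budget of 20;
# vector_length is then clamped to len(vector_result).
```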

hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py

Lines changed: 8 additions & 1 deletion
```diff
@@ -100,12 +100,14 @@ def keywords_to_vid(
         by: Literal["query", "keywords"] = "keywords",
         topk_per_keyword: int = huge_settings.topk_per_keyword,
         topk_per_query: int = 10,
+        vector_dis_threshold: float = huge_settings.vector_dis_threshold,
     ):
         """
         Add a semantic ID query operator to the pipeline.
         :param by: Match by query or keywords.
         :param topk_per_keyword: Top K results per keyword.
         :param topk_per_query: Top K results per query.
+        :param vector_dis_threshold: Vector distance threshold.
         :return: Self-instance for chaining.
         """
         self._operators.append(
@@ -114,6 +116,7 @@ def keywords_to_vid(
                 by=by,
                 topk_per_keyword=topk_per_keyword,
                 topk_per_query=topk_per_query,
+                vector_dis_threshold=vector_dis_threshold,
             )
         )
         return self
@@ -174,6 +177,7 @@ def merge_dedup_rerank(
         rerank_method: Literal["bleu", "reranker"] = "bleu",
         near_neighbor_first: bool = False,
         custom_related_information: str = "",
+        topk_return_results: int = huge_settings.topk_return_results,
     ):
         """
         Add a merge, deduplication, and rerank operator to the pipeline.
@@ -187,6 +191,7 @@ def merge_dedup_rerank(
                 method=rerank_method,
                 near_neighbor_first=near_neighbor_first,
                 custom_related_information=custom_related_information,
+                topk_return_results=topk_return_results
             )
         )
         return self
@@ -239,7 +244,9 @@ def run(self, **kwargs) -> Dict[str, Any]:
         :return: Final context after all operators have been executed.
         """
         if len(self._operators) == 0:
-            self.extract_keywords().query_graphdb().synthesize_answer()
+            self.extract_keywords().query_graphdb(
+                max_graph_items=kwargs.get('max_graph_items')
+            ).synthesize_answer()
 
         context = kwargs
```
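Taken together, the fluent pipeline now threads all four knobs through `keywords_to_vid`, `query_graphdb`, and `merge_dedup_rerank`. A minimal sketch mirroring the call chain in `text2gremlin_block.py` above (assuming `RAGPipeline` is the class defined in `graph_rag_task.py`; values are illustrative):

```python
from hugegraph_llm.config import huge_settings
from hugegraph_llm.operators.graph_rag_task import RAGPipeline  # assumed import path

rag = RAGPipeline()
rag.extract_keywords().keywords_to_vid(
    vector_dis_threshold=0.9,  # drop vid matches farther than this distance
    topk_per_keyword=1,        # keep only the closest vid per keyword
).import_schema(huge_settings.graph_name).query_graphdb(
    num_gremlin_generate_example=1,
    max_graph_items=30,        # cap items returned by the generated graph queries
).merge_dedup_rerank(
    rerank_method="bleu",
    topk_return_results=20,    # final number of sorted results
)
context = rag.run(verbose=True, query="Tell me about the graph.", graph_search=True)
```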

hugegraph-llm/src/hugegraph_llm/operators/index_op/semantic_id_query.py

Lines changed: 4 additions & 2 deletions
```diff
@@ -34,14 +34,16 @@ def __init__(
         embedding: BaseEmbedding,
         by: Literal["query", "keywords"] = "keywords",
         topk_per_query: int = 10,
-        topk_per_keyword: int = huge_settings.topk_per_keyword
+        topk_per_keyword: int = huge_settings.topk_per_keyword,
+        vector_dis_threshold: float = huge_settings.vector_dis_threshold,
     ):
         self.index_dir = str(os.path.join(resource_path, huge_settings.graph_name, "graph_vids"))
         self.vector_index = VectorIndex.from_index_file(self.index_dir)
         self.embedding = embedding
         self.by = by
         self.topk_per_query = topk_per_query
         self.topk_per_keyword = topk_per_keyword
+        self.vector_dis_threshold = vector_dis_threshold
         self._client = PyHugeClient(
             huge_settings.graph_ip,
             huge_settings.graph_port,
@@ -76,7 +78,7 @@ def _fuzzy_match_vids(self, keywords: List[str]) -> List[str]:
         for keyword in keywords:
             keyword_vector = self.embedding.get_text_embedding(keyword)
             results = self.vector_index.search(keyword_vector, top_k=self.topk_per_keyword,
-                                               dis_threshold=float(huge_settings.vector_dis_threshold))
+                                               dis_threshold=float(self.vector_dis_threshold))
             if results:
                 fuzzy_match_result.extend(results[:self.topk_per_keyword])
         return fuzzy_match_result
```
