Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class RAGRequest(BaseModel):
prompt.gremlin_generate_prompt,
description="Prompt for the Text2Gremlin query.",
)
stream: bool = Query(False, description="Enable streaming response")


# TODO: import the default value of prompt.* dynamically
Expand All @@ -58,6 +59,7 @@ class GraphRAGRequest(BaseModel):
prompt.gremlin_generate_prompt,
description="Prompt for the Text2Gremlin query.",
)
stream: bool = Query(False, description="Enable streaming response")


class GraphConfigRequest(BaseModel):
Expand Down Expand Up @@ -94,4 +96,4 @@ class RerankerConfigRequest(BaseModel):

class LogStreamRequest(BaseModel):
admin_token: Optional[str] = None
log_file: Optional[str] = "llm-server.log"
log_file: Optional[str] = "llm-server.log"
296 changes: 233 additions & 63 deletions hugegraph-llm/src/hugegraph_llm/api/rag_api.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, could tell me how u test the APIs? By directly request them?

The gradio UI loss the API link now

Before:
image

Now:
image

Maybe refer here: (Or Gradio's mount doc?)

Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@
# under the License.

import json
import asyncio
from typing import AsyncGenerator

from fastapi import status, APIRouter, HTTPException
from fastapi.responses import StreamingResponse

from hugegraph_llm.api.exceptions.rag_exceptions import generate_response
from hugegraph_llm.api.models.rag_requests import (
Expand All @@ -33,76 +36,243 @@


def rag_http_api(
router: APIRouter,
rag_answer_func,
graph_rag_recall_func,
apply_graph_conf,
apply_llm_conf,
apply_embedding_conf,
apply_reranker_conf,
router: APIRouter,
rag_answer_func,
graph_rag_recall_func,
apply_graph_conf,
apply_llm_conf,
apply_embedding_conf,
apply_reranker_conf,
rag_answer_stream_func=None,
graph_rag_recall_stream_func=None,
):
@router.post("/rag", status_code=status.HTTP_200_OK)
def rag_answer_api(req: RAGRequest):
result = rag_answer_func(
text=req.query,
raw_answer=req.raw_answer,
vector_only_answer=req.vector_only,
graph_only_answer=req.graph_only,
graph_vector_answer=req.graph_vector_answer,
graph_ratio=req.graph_ratio,
rerank_method=req.rerank_method,
near_neighbor_first=req.near_neighbor_first,
custom_related_information=req.custom_priority_info,
answer_prompt=req.answer_prompt or prompt.answer_prompt,
keywords_extract_prompt=req.keywords_extract_prompt or prompt.keywords_extract_prompt,
gremlin_tmpl_num=req.gremlin_tmpl_num,
gremlin_prompt=req.gremlin_prompt or prompt.gremlin_generate_prompt,
)
# TODO: we need more info in the response for users to understand the query logic
return {
"query": req.query,
**{
key: value
for key, value in zip(["raw_answer", "vector_only", "graph_only", "graph_vector_answer"], result)
if getattr(req, key)
},
}
async def stream_rag_answer(
text,
raw_answer,
vector_only_answer,
graph_only_answer,
graph_vector_answer,
graph_ratio,
rerank_method,
near_neighbor_first,
custom_related_information,
answer_prompt,
keywords_extract_prompt,
gremlin_tmpl_num,
gremlin_prompt,
) -> AsyncGenerator[str, None]:
"""
Stream the RAG answer results
"""
if rag_answer_stream_func:
# If a streaming-specific function exists, use it
async for chunk in rag_answer_stream_func(
text=text,
raw_answer=raw_answer,
vector_only_answer=vector_only_answer,
graph_only_answer=graph_only_answer,
graph_vector_answer=graph_vector_answer,
graph_ratio=graph_ratio,
rerank_method=rerank_method,
near_neighbor_first=near_neighbor_first,
custom_related_information=custom_related_information,
answer_prompt=answer_prompt,
keywords_extract_prompt=keywords_extract_prompt,
gremlin_tmpl_num=gremlin_tmpl_num,
gremlin_prompt=gremlin_prompt,
):
yield f"data: {json.dumps({'chunk': chunk})}\n\n"
else:
# Otherwise, use the normal function but adapt it for streaming
# by sending the entire result at once
result = rag_answer_func(
text=text,
raw_answer=raw_answer,
vector_only_answer=vector_only_answer,
graph_only_answer=graph_only_answer,
graph_vector_answer=graph_vector_answer,
graph_ratio=graph_ratio,
rerank_method=rerank_method,
near_neighbor_first=near_neighbor_first,
custom_related_information=custom_related_information,
answer_prompt=answer_prompt,
keywords_extract_prompt=keywords_extract_prompt,
gremlin_tmpl_num=gremlin_tmpl_num,
gremlin_prompt=gremlin_prompt,
)

@router.post("/rag/graph", status_code=status.HTTP_200_OK)
def graph_rag_recall_api(req: GraphRAGRequest):
try:
result = graph_rag_recall_func(
query=req.query,
gremlin_tmpl_num=req.gremlin_tmpl_num,
response_data = {
"query": text,
**{
key: value
for key, value in zip(["raw_answer", "vector_only", "graph_only", "graph_vector_answer"], result)
if eval(key) # Convert string to boolean
},
}

yield f"data: {json.dumps(response_data)}\n\n"
# Signal end of stream
yield "data: [DONE]\n\n"

async def stream_graph_rag_recall(
query,
gremlin_tmpl_num,
rerank_method,
near_neighbor_first,
custom_related_information,
gremlin_prompt,
) -> AsyncGenerator[str, None]:
"""
Stream the graph RAG recall results
"""
if graph_rag_recall_stream_func:
# If a streaming-specific function exists, use it
async for chunk in graph_rag_recall_stream_func(
query=query,
gremlin_tmpl_num=gremlin_tmpl_num,
rerank_method=rerank_method,
near_neighbor_first=near_neighbor_first,
custom_related_information=custom_related_information,
gremlin_prompt=gremlin_prompt,
):
yield f"data: {json.dumps({'chunk': chunk})}\n\n"
else:
# Otherwise, use the normal function but adapt it for streaming
try:
result = graph_rag_recall_func(
query=query,
gremlin_tmpl_num=gremlin_tmpl_num,
rerank_method=rerank_method,
near_neighbor_first=near_neighbor_first,
custom_related_information=custom_related_information,
gremlin_prompt=gremlin_prompt,
)

if isinstance(result, dict):
params = [
"query",
"keywords",
"match_vids",
"graph_result_flag",
"gremlin",
"graph_result",
"vertex_degree_list",
]
user_result = {key: result[key] for key in params if key in result}
yield f"data: {json.dumps({'graph_recall': user_result})}\n\n"
else:
# Note: Maybe only for qianfan/wenxin
yield f"data: {json.dumps({'graph_recall': json.dumps(result)})}\n\n"

# Signal end of stream
yield "data: [DONE]\n\n"

except TypeError as e:
log.error("TypeError in stream_graph_rag_recall: %s", e)
yield f"data: {json.dumps({'error': str(e), 'status': 400})}\n\n"
except Exception as e:
log.error("Unexpected error occurred: %s", e)
yield f"data: {json.dumps({'error': 'An unexpected error occurred.', 'status': 500})}\n\n"

@router.post("/rag", status_code=status.HTTP_200_OK)
async def rag_answer_api(req: RAGRequest):
if req.stream:
# Return a streaming response
return StreamingResponse(
stream_rag_answer(
text=req.query,
raw_answer=req.raw_answer,
vector_only_answer=req.vector_only,
graph_only_answer=req.graph_only,
graph_vector_answer=req.graph_vector_answer,
graph_ratio=req.graph_ratio,
rerank_method=req.rerank_method,
near_neighbor_first=req.near_neighbor_first,
custom_related_information=req.custom_priority_info,
answer_prompt=req.answer_prompt or prompt.answer_prompt,
keywords_extract_prompt=req.keywords_extract_prompt or prompt.keywords_extract_prompt,
gremlin_tmpl_num=req.gremlin_tmpl_num,
gremlin_prompt=req.gremlin_prompt or prompt.gremlin_generate_prompt,
),
media_type="text/event-stream",
)
else:
# Synchronous response (original behavior)
result = rag_answer_func(
text=req.query,
raw_answer=req.raw_answer,
vector_only_answer=req.vector_only,
graph_only_answer=req.graph_only,
graph_vector_answer=req.graph_vector_answer,
graph_ratio=req.graph_ratio,
rerank_method=req.rerank_method,
near_neighbor_first=req.near_neighbor_first,
custom_related_information=req.custom_priority_info,
answer_prompt=req.answer_prompt or prompt.answer_prompt,
keywords_extract_prompt=req.keywords_extract_prompt or prompt.keywords_extract_prompt,
gremlin_tmpl_num=req.gremlin_tmpl_num,
gremlin_prompt=req.gremlin_prompt or prompt.gremlin_generate_prompt,
)
# TODO: we need more info in the response for users to understand the query logic
return {
"query": req.query,
**{
key: value
for key, value in zip(["raw_answer", "vector_only", "graph_only", "graph_vector_answer"], result)
if getattr(req, key)
},
}

if isinstance(result, dict):
params = [
"query",
"keywords",
"match_vids",
"graph_result_flag",
"gremlin",
"graph_result",
"vertex_degree_list",
]
user_result = {key: result[key] for key in params if key in result}
return {"graph_recall": user_result}
# Note: Maybe only for qianfan/wenxin
return {"graph_recall": json.dumps(result)}

except TypeError as e:
log.error("TypeError in graph_rag_recall_api: %s", e)
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
except Exception as e:
log.error("Unexpected error occurred: %s", e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred."
) from e
@router.post("/rag/graph", status_code=status.HTTP_200_OK)
async def graph_rag_recall_api(req: GraphRAGRequest):
if req.stream:
# Return a streaming response
return StreamingResponse(
stream_graph_rag_recall(
query=req.query,
gremlin_tmpl_num=req.gremlin_tmpl_num,
rerank_method=req.rerank_method,
near_neighbor_first=req.near_neighbor_first,
custom_related_information=req.custom_priority_info,
gremlin_prompt=req.gremlin_prompt or prompt.gremlin_generate_prompt,
),
media_type="text/event-stream",
)
else:
# Synchronous response (original behavior)
try:
result = graph_rag_recall_func(
query=req.query,
gremlin_tmpl_num=req.gremlin_tmpl_num,
rerank_method=req.rerank_method,
near_neighbor_first=req.near_neighbor_first,
custom_related_information=req.custom_priority_info,
gremlin_prompt=req.gremlin_prompt or prompt.gremlin_generate_prompt,
)

if isinstance(result, dict):
params = [
"query",
"keywords",
"match_vids",
"graph_result_flag",
"gremlin",
"graph_result",
"vertex_degree_list",
]
user_result = {key: result[key] for key in params if key in result}
return {"graph_recall": user_result}
# Note: Maybe only for qianfan/wenxin
return {"graph_recall": json.dumps(result)}

except TypeError as e:
log.error("TypeError in graph_rag_recall_api: %s", e)
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
except Exception as e:
log.error("Unexpected error occurred: %s", e)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred."
) from e

@router.post("/config/graph", status_code=status.HTTP_201_CREATED)
def graph_config_api(req: GraphConfigRequest):
Expand Down Expand Up @@ -145,4 +315,4 @@
res = apply_reranker_conf(req.api_key, req.reranker_model, None, origin_call="http")
else:
res = status.HTTP_501_NOT_IMPLEMENTED
return generate_response(RAGResponse(status_code=res, message="Missing Value"))
return generate_response(RAGResponse(status_code=res, message="Missing Value"))
Loading