diff --git a/.github/workflows/hugegraph-python-client.yml b/.github/workflows/python-client.yml similarity index 100% rename from .github/workflows/hugegraph-python-client.yml rename to .github/workflows/python-client.yml diff --git a/hugegraph-llm/src/hugegraph_llm/api/admin_api.py b/hugegraph-llm/src/hugegraph_llm/api/admin_api.py index a234c5023..bc316da90 100644 --- a/hugegraph-llm/src/hugegraph_llm/api/admin_api.py +++ b/hugegraph-llm/src/hugegraph_llm/api/admin_api.py @@ -16,21 +16,18 @@ # under the License. import os -from fastapi import status, APIRouter +from fastapi import status, APIRouter, HTTPException from fastapi.responses import StreamingResponse -from hugegraph_llm.api.exceptions.rag_exceptions import generate_response from hugegraph_llm.api.models.rag_requests import LogStreamRequest -from hugegraph_llm.api.models.rag_response import RAGResponse from hugegraph_llm.config import admin_settings -# FIXME: line 31: E0702: Raising dict while only classes or instances are allowed (raising-bad-type) def admin_http_api(router: APIRouter, log_stream): @router.post("/logs", status_code=status.HTTP_200_OK) async def log_stream_api(req: LogStreamRequest): if admin_settings.admin_token != req.admin_token: - raise generate_response(RAGResponse(status_code=status.HTTP_403_FORBIDDEN, message="Invalid admin_token")) #pylint: disable=E0702 + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid admin_token") log_path = os.path.join("logs", req.log_file) # Create a StreamingResponse that reads from the log stream generator diff --git a/hugegraph-llm/src/hugegraph_llm/api/config_api.py b/hugegraph-llm/src/hugegraph_llm/api/config_api.py new file mode 100644 index 000000000..c6b43111f --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/api/config_api.py @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
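The admin_api.py change above replaces the old `raise generate_response(RAGResponse(...))` call, which raised a dict and failed at runtime (the pylint E0702 FIXME), with FastAPI's `HTTPException`. A minimal sketch of the resulting behavior, using an illustrative standalone app and a placeholder token rather than the real `admin_settings`:

```python
from fastapi import FastAPI, HTTPException, status
from fastapi.testclient import TestClient

app = FastAPI()
EXPECTED_TOKEN = "secret"  # placeholder; the real check compares against admin_settings.admin_token

@app.post("/logs")
async def log_stream_api(admin_token: str):
    if admin_token != EXPECTED_TOKEN:
        # FastAPI turns HTTPException into a JSON error response with the given status code
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Invalid admin_token")
    return {"status": "ok"}

client = TestClient(app)
resp = client.post("/logs", params={"admin_token": "wrong"})
assert resp.status_code == 403
assert resp.json() == {"detail": "Invalid admin_token"}
```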
+ +from fastapi import status, APIRouter + +from hugegraph_llm.api.exceptions.rag_exceptions import generate_response +from hugegraph_llm.api.models.rag_requests import ( + GraphConfigRequest, + LLMConfigRequest, + RerankerConfigRequest, +) +from hugegraph_llm.api.models.rag_response import RAGResponse +from hugegraph_llm.config import llm_settings + + +async def graph_config_route(router: APIRouter, apply_graph_conf): + @router.post("/config/graph", status_code=status.HTTP_201_CREATED) + async def graph_config_api(req: GraphConfigRequest): + # Accept status code + res = await apply_graph_conf(req.ip, req.port, req.name, req.user, req.pwd, req.gs, origin_call="http") + return generate_response(RAGResponse(status_code=res, message="Missing Value")) + return graph_config_api + +async def llm_config_route(router: APIRouter, apply_llm_conf): + # TODO: restructure the implement of llm to three types, like "/config/chat_llm" + /config/mini_task_llm + .. + @router.post("/config/llm", status_code=status.HTTP_201_CREATED) + async def llm_config_api(req: LLMConfigRequest): + llm_settings.llm_type = req.llm_type + + if req.llm_type == "openai": + res = await apply_llm_conf(req.api_key, req.api_base, req.language_model, req.max_tokens, + origin_call="http") + elif req.llm_type == "qianfan_wenxin": + res = await apply_llm_conf(req.api_key, req.secret_key, req.language_model, None, origin_call="http") + else: + res = await apply_llm_conf(req.host, req.port, req.language_model, None, origin_call="http") + return generate_response(RAGResponse(status_code=res, message="Missing Value")) + + return llm_config_api + +async def embedding_config_route(router: APIRouter, apply_embedding_conf): + @router.post("/config/embedding", status_code=status.HTTP_201_CREATED) + async def embedding_config_api(req: LLMConfigRequest): + llm_settings.embedding_type = req.llm_type + + if req.llm_type == "openai": + res = await apply_embedding_conf(req.api_key, req.api_base, req.language_model, origin_call="http") + elif req.llm_type == "qianfan_wenxin": + res = await apply_embedding_conf(req.api_key, req.api_base, None, origin_call="http") + else: + res = await apply_embedding_conf(req.host, req.port, req.language_model, origin_call="http") + return generate_response(RAGResponse(status_code=res, message="Missing Value")) + + return embedding_config_api + +async def rerank_config_route(router: APIRouter, apply_reranker_conf): + @router.post("/config/rerank", status_code=status.HTTP_201_CREATED) + async def rerank_config_api(req: RerankerConfigRequest): + llm_settings.reranker_type = req.reranker_type + + if req.reranker_type == "cohere": + res = await apply_reranker_conf(req.api_key, req.reranker_model, req.cohere_base_url, origin_call="http") + elif req.reranker_type == "siliconflow": + res = await apply_reranker_conf(req.api_key, req.reranker_model, None, origin_call="http") + else: + res = status.HTTP_501_NOT_IMPLEMENTED + return generate_response(RAGResponse(status_code=res, message="Missing Value")) + + return rerank_config_api + + +async def config_http_api( + router: APIRouter, + apply_graph_conf, + apply_llm_conf, + apply_embedding_conf, + apply_reranker_conf, +): + await graph_config_route(router, apply_graph_conf) + await llm_config_route(router, apply_llm_conf) + await embedding_config_route(router, apply_embedding_conf) + await rerank_config_route(router, apply_reranker_conf) diff --git a/hugegraph-llm/src/hugegraph_llm/api/exceptions/rag_exceptions.py 
b/hugegraph-llm/src/hugegraph_llm/api/exceptions/rag_exceptions.py index 75eb14cf3..993c495d5 100644 --- a/hugegraph-llm/src/hugegraph_llm/api/exceptions/rag_exceptions.py +++ b/hugegraph-llm/src/hugegraph_llm/api/exceptions/rag_exceptions.py @@ -21,7 +21,8 @@ class ExternalException(HTTPException): def __init__(self): - super().__init__(status_code=400, detail="Connect failed with error code -1, please check the input.") + super().__init__(status_code=400, detail="Connect failed with error code -1, " + "please check the input.") class ConnectionFailedException(HTTPException): diff --git a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py index a6b58b460..bc5b3a030 100644 --- a/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py +++ b/hugegraph-llm/src/hugegraph_llm/api/models/rag_requests.py @@ -50,6 +50,7 @@ class RAGRequest(BaseModel): topk_per_keyword : int = Query(1, description="TopK results returned for each keyword \ extracted from the query, by default only the most similar one is returned.") client_config: Optional[GraphConfigRequest] = Query(None, description="hugegraph server config.") + stream: bool = Query(False, description="Whether to use streaming response") # Keep prompt params in the end answer_prompt: Optional[str] = Query(prompt.answer_prompt, description="Prompt to guide the answer generation.") @@ -77,6 +78,7 @@ class GraphRAGRequest(BaseModel): client_config : Optional[GraphConfigRequest] = Query(None, description="hugegraph server config.") get_vertex_only: bool = Query(False, description="return only keywords & vertex (early stop).") + stream: bool = Query(False, description="Whether to use streaming response") gremlin_tmpl_num: int = Query( 1, description="Number of Gremlin templates to use. 
If num <=0 means template is not provided" diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py index 04c7b9a51..14e5867bc 100644 --- a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py +++ b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py @@ -19,34 +19,36 @@ from fastapi import status, APIRouter, HTTPException -from hugegraph_llm.api.exceptions.rag_exceptions import generate_response from hugegraph_llm.api.models.rag_requests import ( RAGRequest, - GraphConfigRequest, - LLMConfigRequest, - RerankerConfigRequest, GraphRAGRequest, ) from hugegraph_llm.config import huge_settings -from hugegraph_llm.api.models.rag_response import RAGResponse -from hugegraph_llm.config import llm_settings, prompt +from hugegraph_llm.config import prompt from hugegraph_llm.utils.log import log +from hugegraph_llm.api.config_api import ( + graph_config_route, + llm_config_route, + embedding_config_route, + rerank_config_route +) + # pylint: disable=too-many-statements -def rag_http_api( +async def rag_http_api( router: APIRouter, rag_answer_func, graph_rag_recall_func, - apply_graph_conf, - apply_llm_conf, - apply_embedding_conf, - apply_reranker_conf, + apply_graph_conf= None, + apply_llm_conf= None, + apply_embedding_conf= None, + apply_reranker_conf= None, ): @router.post("/rag", status_code=status.HTTP_200_OK) - def rag_answer_api(req: RAGRequest): + async def rag_answer_api(req: RAGRequest): set_graph_config(req) - result = rag_answer_func( + result = await rag_answer_func( text=req.query, raw_answer=req.raw_answer, vector_only_answer=req.vector_only, @@ -86,11 +88,11 @@ def set_graph_config(req): huge_settings.graph_space = req.client_config.gs @router.post("/rag/graph", status_code=status.HTTP_200_OK) - def graph_rag_recall_api(req: GraphRAGRequest): + async def graph_rag_recall_api(req: GraphRAGRequest): try: set_graph_config(req) - result = graph_rag_recall_func( + result = await graph_rag_recall_func( query=req.query, max_graph_items=req.max_graph_items, topk_return_results=req.topk_return_results, @@ -108,7 +110,7 @@ def graph_rag_recall_api(req: GraphRAGRequest): from hugegraph_llm.operators.hugegraph_op.graph_rag_query import GraphRAGQuery graph_rag = GraphRAGQuery() graph_rag.init_client(result) - vertex_details = graph_rag.get_vertex_details(result["match_vids"]) + vertex_details = await graph_rag.get_vertex_details(result["match_vids"]) if vertex_details: result["match_vids"] = vertex_details @@ -137,45 +139,7 @@ def graph_rag_recall_api(req: GraphRAGRequest): status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred." 
) from e - @router.post("/config/graph", status_code=status.HTTP_201_CREATED) - def graph_config_api(req: GraphConfigRequest): - # Accept status code - res = apply_graph_conf(req.ip, req.port, req.name, req.user, req.pwd, req.gs, origin_call="http") - return generate_response(RAGResponse(status_code=res, message="Missing Value")) - - # TODO: restructure the implement of llm to three types, like "/config/chat_llm" - @router.post("/config/llm", status_code=status.HTTP_201_CREATED) - def llm_config_api(req: LLMConfigRequest): - llm_settings.llm_type = req.llm_type - - if req.llm_type == "openai": - res = apply_llm_conf(req.api_key, req.api_base, req.language_model, req.max_tokens, origin_call="http") - elif req.llm_type == "qianfan_wenxin": - res = apply_llm_conf(req.api_key, req.secret_key, req.language_model, None, origin_call="http") - else: - res = apply_llm_conf(req.host, req.port, req.language_model, None, origin_call="http") - return generate_response(RAGResponse(status_code=res, message="Missing Value")) - - @router.post("/config/embedding", status_code=status.HTTP_201_CREATED) - def embedding_config_api(req: LLMConfigRequest): - llm_settings.embedding_type = req.llm_type - - if req.llm_type == "openai": - res = apply_embedding_conf(req.api_key, req.api_base, req.language_model, origin_call="http") - elif req.llm_type == "qianfan_wenxin": - res = apply_embedding_conf(req.api_key, req.api_base, None, origin_call="http") - else: - res = apply_embedding_conf(req.host, req.port, req.language_model, origin_call="http") - return generate_response(RAGResponse(status_code=res, message="Missing Value")) - - @router.post("/config/rerank", status_code=status.HTTP_201_CREATED) - def rerank_config_api(req: RerankerConfigRequest): - llm_settings.reranker_type = req.reranker_type - - if req.reranker_type == "cohere": - res = apply_reranker_conf(req.api_key, req.reranker_model, req.cohere_base_url, origin_call="http") - elif req.reranker_type == "siliconflow": - res = apply_reranker_conf(req.api_key, req.reranker_model, None, origin_call="http") - else: - res = status.HTTP_501_NOT_IMPLEMENTED - return generate_response(RAGResponse(status_code=res, message="Missing Value")) + await graph_config_route(router, apply_graph_conf) + await llm_config_route(router, apply_llm_conf) + await embedding_config_route(router, apply_embedding_conf) + await rerank_config_route(router, apply_reranker_conf) diff --git a/hugegraph-llm/src/hugegraph_llm/api/stream_api.py b/hugegraph-llm/src/hugegraph_llm/api/stream_api.py new file mode 100644 index 000000000..0000a7179 --- /dev/null +++ b/hugegraph-llm/src/hugegraph_llm/api/stream_api.py @@ -0,0 +1,177 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
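With the four `/config/*` routes (graph, llm, embedding, rerank) now registered by config_api.py instead of being defined inline in rag_api.py, a client can apply configuration independently of the RAG endpoints. A hedged example of applying a graph config over HTTP, assuming the router is mounted without a path prefix and using placeholder connection values; the field names mirror `GraphConfigRequest`:

```python
import requests

BASE_URL = "http://127.0.0.1:8001"  # illustrative address; use your deployment's host/port

payload = {
    "ip": "127.0.0.1",
    "port": "8080",
    "name": "hugegraph",
    "user": "admin",
    "pwd": "xxx",
    "gs": "",  # graph space; may be empty when graph spaces are not used
}
resp = requests.post(f"{BASE_URL}/config/graph", json=payload, timeout=10)
print(resp.status_code, resp.json())
```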
+ +import asyncio +import json + +from fastapi import status, APIRouter, HTTPException +from fastapi.responses import StreamingResponse + +from hugegraph_llm.api.models.rag_requests import ( + RAGRequest, + GraphRAGRequest, +) +from hugegraph_llm.config import prompt, huge_settings +from hugegraph_llm.utils.log import log + + +# pylint: disable=too-many-statements +async def stream_http_api( + router: APIRouter, + rag_answer_stream_func, + graph_rag_recall_stream_func, +): + @router.post("/rag/stream", status_code=status.HTTP_200_OK) + async def rag_answer_stream_api(req: RAGRequest): + if not req.stream: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Stream parameter must be set to True for streaming endpoint" + ) + + if req.client_config: + huge_settings.graph_ip = req.client_config.ip + huge_settings.graph_port = req.client_config.port + huge_settings.graph_name = req.client_config.name + huge_settings.graph_user = req.client_config.user + huge_settings.graph_pwd = req.client_config.pwd + huge_settings.graph_space = req.client_config.gs + + async def generate_stream(): + try: + async for chunk in rag_answer_stream_func( + text=req.query, + raw_answer=req.raw_answer, + vector_only_answer=req.vector_only, + graph_only_answer=req.graph_only, + graph_vector_answer=req.graph_vector_answer, + graph_ratio=req.graph_ratio, + rerank_method=req.rerank_method, + near_neighbor_first=req.near_neighbor_first, + gremlin_tmpl_num=req.gremlin_tmpl_num, + max_graph_items=req.max_graph_items, + topk_return_results=req.topk_return_results, + vector_dis_threshold=req.vector_dis_threshold, + topk_per_keyword=req.topk_per_keyword, + # Keep prompt params in the end + custom_related_information=req.custom_priority_info, + answer_prompt=req.answer_prompt or prompt.answer_prompt, + keywords_extract_prompt=req.keywords_extract_prompt or prompt.keywords_extract_prompt, + gremlin_prompt=req.gremlin_prompt or prompt.gremlin_generate_prompt, + ): + # Format as Server-Sent Events + data = json.dumps({ + "query": req.query, + "chunk": chunk + }) + yield f"data: {data}\n\n" + await asyncio.sleep(0.01) # Small delay to prevent overwhelming + except (ValueError, TypeError) as e: # More specific exceptions + log.error("Error in streaming RAG response: %s", e) + error_data = json.dumps({"error": str(e)}) + yield f"data: {error_data}\n\n" + except Exception as e: # pylint: disable=broad-exception-caught + # We need to catch all exceptions here to ensure proper error response + log.error("Unexpected error in streaming RAG response: %s", e) + error_data = json.dumps({"error": "An unexpected error occurred"}) + yield f"data: {error_data}\n\n" + + return StreamingResponse( + generate_stream(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + } + ) + + @router.post("/rag/graph/stream", status_code=status.HTTP_200_OK) + async def graph_rag_recall_stream_api(req: GraphRAGRequest): + if not req.stream: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Stream parameter must be set to True for streaming endpoint" + ) + + # Set graph config if provided + if req.client_config: + huge_settings.graph_ip = req.client_config.ip + huge_settings.graph_port = req.client_config.port + huge_settings.graph_name = req.client_config.name + huge_settings.graph_user = req.client_config.user + huge_settings.graph_pwd = req.client_config.pwd + huge_settings.graph_space = req.client_config.gs + + async def generate_graph_stream(): + try: + 
async for chunk in graph_rag_recall_stream_func( + query=req.query, + max_graph_items=req.max_graph_items, + topk_return_results=req.topk_return_results, + vector_dis_threshold=req.vector_dis_threshold, + topk_per_keyword=req.topk_per_keyword, + gremlin_tmpl_num=req.gremlin_tmpl_num, + rerank_method=req.rerank_method, + near_neighbor_first=req.near_neighbor_first, + custom_related_information=req.custom_priority_info, + gremlin_prompt=req.gremlin_prompt or prompt.gremlin_generate_prompt, + get_vertex_only=req.get_vertex_only + ): + # Handle vertex details for a get_vertex_only flag + if req.get_vertex_only and isinstance(chunk, dict) and "match_vids" in chunk: + from hugegraph_llm.operators.hugegraph_op.graph_rag_query import GraphRAGQuery + graph_rag = GraphRAGQuery() + graph_rag.init_client(chunk) + vertex_details = await graph_rag.get_vertex_details(chunk["match_vids"]) + if vertex_details: + chunk["match_vids"] = vertex_details + + if isinstance(chunk, dict): + params = [ + "query", + "keywords", + "match_vids", + "graph_result_flag", + "gremlin", + "graph_result", + "vertex_degree_list", + ] + user_result = {key: chunk[key] for key in params if key in chunk} + data = json.dumps({"graph_recall": user_result}) + else: + data = json.dumps({"graph_recall": json.dumps(chunk)}) + + yield f"data: {data}\n\n" + await asyncio.sleep(0.01) # Small delay + except TypeError as e: + log.error("TypeError in streaming graph RAG recall: %s", e) + error_data = json.dumps({"error": str(e)}) + yield f"data: {error_data}\n\n" + except Exception as e: # pylint: disable=broad-exception-caught + # We need to catch all exceptions here to ensure proper error response + log.error("Unexpected error in streaming graph RAG recall: %s", e) + error_data = json.dumps({"error": "An unexpected error occurred"}) + yield f"data: {error_data}\n\n" + + return StreamingResponse( + generate_graph_stream(), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + } + ) diff --git a/hugegraph-llm/src/tests/api/test_rag_api.py b/hugegraph-llm/src/tests/api/test_rag_api.py new file mode 100644 index 000000000..b31722cdb --- /dev/null +++ b/hugegraph-llm/src/tests/api/test_rag_api.py @@ -0,0 +1,136 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
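The two streaming endpoints above emit Server-Sent Events: each chunk is serialized as a `data: {...}` line followed by a blank line, and errors are delivered in-band as `{"error": ...}` payloads. A sketch of a client consuming `/rag/stream`, with an illustrative base URL and query; `httpx` is used here only as an example, any SSE-capable HTTP client works:

```python
import json
import httpx

BASE_URL = "http://127.0.0.1:8001"  # placeholder address
payload = {"query": "Tell me about Al Pacino.", "vector_only": True, "stream": True}

with httpx.stream("POST", f"{BASE_URL}/rag/stream", json=payload, timeout=None) as resp:
    for line in resp.iter_lines():
        if not line.startswith("data: "):
            continue  # skip the blank separator lines between events
        event = json.loads(line[len("data: "):])
        if "error" in event:
            print("stream error:", event["error"])
            break
        print(event["chunk"], end="", flush=True)
```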
+ +import asyncio +import unittest + +from fastapi import FastAPI, APIRouter +from fastapi.testclient import TestClient + +from hugegraph_llm.api.rag_api import rag_http_api + + +class MockAsyncFunction: + """Helper class to mock async functions""" + + def __init__(self, return_value): + self.return_value = return_value + self.called = False + self.last_args = None + self.last_kwargs = None + + async def __call__(self, *args, **kwargs): + self.called = True + self.last_args = args + self.last_kwargs = kwargs + return self.return_value + + +class TestRagApi(unittest.TestCase): + def setUp(self): + self.app = FastAPI() + self.router = APIRouter() + + # Mock RAG answer function + self.mock_rag_answer = MockAsyncFunction( + ["Test raw answer", "Test vector answer", "Test graph answer", "Test combined answer"] + ) + + # Mock graph RAG recall function + self.mock_graph_rag_recall = MockAsyncFunction({ + "query": "test query", + "keywords": ["test", "keyword"], + "match_vids": ["1", "2"], + "graph_result_flag": True, + "gremlin": "g.V().has('name', 'test')", + "graph_result": ["result1", "result2"], + "vertex_degree_list": [1, 2] + }) + + # Set up the API + loop = asyncio.get_event_loop() + loop.run_until_complete( + rag_http_api( + router=self.router, + rag_answer_func=self.mock_rag_answer, + graph_rag_recall_func=self.mock_graph_rag_recall + ) + ) + + self.app.include_router(self.router) + self.client = TestClient(self.app) + + def test_rag_answer_api(self): + """Test the /rag endpoint""" + # Prepare test request + request_data = { + "query": "test query", + "raw_answer": True, + "vector_only": True, + "graph_only": True, + "graph_vector_answer": True + } + + # Send request + response = self.client.post("/rag", json=request_data) + + # Check response + self.assertEqual(response.status_code, 200) + self.assertTrue(self.mock_rag_answer.called) + self.assertEqual(self.mock_rag_answer.last_kwargs["text"], "test query") + + # Check response content + response_data = response.json() + self.assertEqual(response_data["query"], "test query") + self.assertEqual(response_data["raw_answer"], "Test raw answer") + self.assertEqual(response_data["vector_only"], "Test vector answer") + self.assertEqual(response_data["graph_only"], "Test graph answer") + self.assertEqual(response_data["graph_vector_answer"], "Test combined answer") + + def test_graph_rag_recall_api(self): + """Test the /rag/graph endpoint""" + # Prepare test request + request_data = { + "query": "test query", + "gremlin_tmpl_num": 1, + "rerank_method": "bleu", + "near_neighbor_first": False, + "custom_priority_info": "", + "stream": False + } + + # Send request + response = self.client.post("/rag/graph", json=request_data) + + # Check response + self.assertEqual(response.status_code, 200) + self.assertTrue(self.mock_graph_rag_recall.called) + self.assertEqual(self.mock_graph_rag_recall.last_kwargs["query"], "test query") + + # Check response content + response_data = response.json() + self.assertIn("graph_recall", response_data) + graph_recall = response_data["graph_recall"] + self.assertEqual(graph_recall["query"], "test query") + self.assertListEqual(graph_recall["keywords"], ["test", "keyword"]) + self.assertListEqual(graph_recall["match_vids"], ["1", "2"]) + self.assertTrue(graph_recall["graph_result_flag"]) + self.assertEqual(graph_recall["gremlin"], "g.V().has('name', 'test')") + + +if __name__ == "__main__": + unittest.main()
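The test module registers the now-async `rag_http_api` by driving `asyncio.get_event_loop().run_until_complete(...)` in `setUp`. On recent Python versions, `get_event_loop()` emits a DeprecationWarning when called without a running loop; a sketch of one alternative setup that sidesteps this by letting `asyncio.run` own the loop (illustrative only, not something the patch itself does):

```python
import asyncio

from fastapi import FastAPI, APIRouter
from fastapi.testclient import TestClient

from hugegraph_llm.api.rag_api import rag_http_api


def build_test_client(rag_answer_func, graph_rag_recall_func) -> TestClient:
    """Register the async route factory once and return a ready TestClient."""
    app = FastAPI()
    router = APIRouter()
    # asyncio.run creates and closes its own event loop, avoiding get_event_loop()
    asyncio.run(rag_http_api(router=router,
                             rag_answer_func=rag_answer_func,
                             graph_rag_recall_func=graph_rag_recall_func))
    app.include_router(router)
    return TestClient(app)
```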