Skip to content

Commit 106b92c

Browse files
Text2Query: Add RESTful API for database health check (opea-project#1935)
* Text2Query: Add RESTful API for database health check Signed-off-by: Yi Yao <[email protected]> --------- Signed-off-by: Yi Yao <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 71dcdcf commit 106b92c

File tree

6 files changed

+121
-4
lines changed

6 files changed

+121
-4
lines changed

comps/text2query/src/integrations/text2cypher.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,3 +215,16 @@ async def invoke(self, input: Text2QueryRequest):
215215
raise
216216

217217
return result
218+
219+
async def db_connection_check(self, request: Text2QueryRequest):
220+
"""Check the connection to Neo4j database.
221+
222+
This function takes a Text2QueryRequest object containing the database connection information.
223+
It attempts to connect to the database using the provided connection URL and credentials.
224+
225+
Args:
226+
request (Text2QueryRequest): A Text2QueryRequest object with the database connection information.
227+
Returns:
228+
dict: A dictionary with a 'status' key indicating whether the connection was successful or failed.
229+
"""
230+
return {"status": "Connection successful"}

comps/text2query/src/integrations/text2sql.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,12 @@
66
import os
77
from urllib.parse import urlparse, urlunparse
88

9+
from fastapi.exceptions import HTTPException
910
from langchain.agents.agent_types import AgentType
1011
from langchain_community.utilities.sql_database import SQLDatabase
1112
from langchain_huggingface import HuggingFaceEndpoint
13+
from sqlalchemy import create_engine
14+
from sqlalchemy.exc import SQLAlchemyError
1215

1316
from comps import CustomLogger, OpeaComponent, OpeaComponentRegistry, ServiceType
1417
from comps.cores.proto.api_protocol import Text2QueryRequest
@@ -102,7 +105,9 @@ def format_db_url(self, request: Text2QueryRequest) -> str:
102105

103106
async def invoke(self, request: Text2QueryRequest):
104107
url = request.conn_url
105-
if not url:
108+
if url:
109+
url = self.format_db_url(request)
110+
else:
106111
raise ValueError("Database connection URL must be provided in 'conn_url' field of the request.")
107112

108113
"""Execute a SQL query using the custom SQL agent.
@@ -132,3 +137,29 @@ async def invoke(self, request: Text2QueryRequest):
132137
query.append(log.tool_input)
133138
result["sql"] = query[0].replace("Observation", "")
134139
return {"result": result}
140+
141+
async def db_connection_check(self, request: Text2QueryRequest):
142+
"""Check the connection to the database.
143+
144+
This function takes a Text2QueryRequest object containing the database connection information.
145+
It attempts to connect to the database using the provided connection URL and credentials.
146+
147+
Args:
148+
request (Text2QueryRequest): A Text2QueryRequest object with the database connection information.
149+
Returns:
150+
dict: A dictionary with a 'status' key indicating whether the connection was successful or failed.
151+
"""
152+
url = request.conn_url
153+
if url:
154+
url = self.format_db_url(request)
155+
else:
156+
raise ValueError("Database connection URL must be provided in 'conn_url' field of the request.")
157+
158+
try:
159+
engine = create_engine(url)
160+
with engine.connect() as _:
161+
# If the connection is successful, return True
162+
return {"status": "Connection successful"}
163+
except SQLAlchemyError as e:
164+
logger.error(f"Connection failed: {e}")
165+
raise HTTPException(status_code=500, detail=f"Failed to connect to database: {url}")
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Copyright (C) 2024 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
from comps import OpeaComponentLoader
5+
6+
7+
class OpeaText2QueryLoader(OpeaComponentLoader):
8+
9+
async def db_connection_check(self, *args, **kwargs):
10+
return await self.component.db_connection_check(*args, **kwargs)

comps/text2query/src/opea_text2query_microservice.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,12 @@
1313

1414
import os
1515

16-
from comps import CustomLogger, OpeaComponentLoader, opea_microservices, register_microservice
16+
from fastapi import status
17+
from fastapi.exceptions import HTTPException
18+
19+
from comps import CustomLogger, opea_microservices, register_microservice
1720
from comps.cores.proto.api_protocol import Text2QueryRequest
21+
from comps.text2query.src.opea_text2query_loader import OpeaText2QueryLoader
1822

1923
logger = CustomLogger("text2query")
2024
logflag = os.getenv("LOGFLAG", False)
@@ -34,7 +38,7 @@
3438
raise ValueError(f"Unsupported TEXT2QUERY_COMPONENT_NAME: {component_name}")
3539

3640
# Initialize the OPEA component loader with the selected component
37-
loader = OpeaComponentLoader(
41+
loader = OpeaText2QueryLoader(
3842
component_name,
3943
description=f"OPEA TEXT2QUERY Component: {component_name}",
4044
)
@@ -63,6 +67,32 @@ async def execute_agent(request: Text2QueryRequest):
6367
return await loader.invoke(request)
6468

6569

70+
@register_microservice(
71+
name="opea_service@text2query",
72+
endpoint="/v1/db/health",
73+
host="0.0.0.0",
74+
port=9097,
75+
)
76+
async def db_connection_check(request: Text2QueryRequest):
77+
"""Check the connection to the database.
78+
79+
This function takes an Input object containing the database connection information.
80+
It uses the test_connection method of the PostgresConnection class to check if the connection is successful.
81+
82+
Args:
83+
request (Text2QueryRequest): An Input object with the database connection information.
84+
85+
Returns:
86+
dict: A dictionary with a 'status' key indicating whether the connection was successful or failed.
87+
"""
88+
logger.info(f"Received input for connection check: {request}")
89+
if not isinstance(request, Text2QueryRequest):
90+
raise HTTPException(
91+
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail="Input type mismatch: expected Text2QueryRequest"
92+
)
93+
return await loader.db_connection_check(request)
94+
95+
6696
if __name__ == "__main__":
6797
logger.info("OPEA Text2Query Microservice is starting...")
6898
opea_microservices["opea_service@text2query"].start()

tests/text2query/test_text2query_text2sql.sh

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,23 @@ function start_service() {
6969

7070
function validate_microservice() {
7171
url="postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@${ip_address}:5442/${POSTGRES_DB}"
72+
73+
echo "Validating v1/db/health..."
74+
result=$(http_proxy="" curl http://${ip_address}:${TEXT2SQL_PORT}/v1/db/health\
75+
-X POST \
76+
-d '{"conn_type": "sql", "conn_url": "'${url}'", "conn_user": "'${POSTGRES_USER}'","conn_password": "'${POSTGRES_PASSWORD}'","conn_dialect": "postgresql" }' \
77+
-H 'Content-Type: application/json')
78+
79+
if [[ $result == *"Connection successful"* ]]; then
80+
echo $result
81+
echo "Result correct."
82+
else
83+
echo "Result wrong. Received was $result"
84+
docker logs text2query-sql-server > ${LOG_PATH}/text2query.log
85+
exit 1
86+
fi
87+
88+
echo "Validating v1/text2query..."
7289
result=$(http_proxy="" curl http://${ip_address}:${TEXT2SQL_PORT}/v1/text2query\
7390
-X POST \
7491
-d '{"query": "Find the total number of Albums.","conn_type": "sql", "conn_url": "'${url}'", "conn_user": "'${POSTGRES_USER}'","conn_password": "'${POSTGRES_PASSWORD}'","conn_dialect": "postgresql" }' \
@@ -83,7 +100,6 @@ function validate_microservice() {
83100
docker logs tgi-server > ${LOG_PATH}/tgi.log
84101
exit 1
85102
fi
86-
87103
}
88104

89105
function stop_docker() {

tests/text2query/test_text2query_text2sql_on_intel_hpu.sh

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,23 @@ function start_service() {
6868

6969
function validate_microservice() {
7070
url="postgresql://${POSTGRES_USER}:${POSTGRES_PASSWORD}@${ip_address}:5442/${POSTGRES_DB}"
71+
72+
echo "Validating v1/db/health..."
73+
result=$(http_proxy="" curl http://${ip_address}:${TEXT2SQL_PORT}/v1/db/health\
74+
-X POST \
75+
-d '{"conn_type": "sql", "conn_url": "'${url}'", "conn_user": "'${POSTGRES_USER}'","conn_password": "'${POSTGRES_PASSWORD}'","conn_dialect": "postgresql" }' \
76+
-H 'Content-Type: application/json')
77+
78+
if [[ $result == *"Connection successful"* ]]; then
79+
echo $result
80+
echo "Result correct."
81+
else
82+
echo "Result wrong. Received was $result"
83+
docker logs text2query-sql-server > ${LOG_PATH}/text2query.log
84+
exit 1
85+
fi
86+
87+
echo "Validating v1/text2query..."
7188
result=$(http_proxy="" curl http://${ip_address}:${TEXT2SQL_PORT}/v1/text2query\
7289
-X POST \
7390
-d '{"query": "Find the total number of Albums.","conn_type": "sql", "conn_url": "'${url}'", "conn_user": "'${POSTGRES_USER}'","conn_password": "'${POSTGRES_PASSWORD}'","conn_dialect": "postgresql" }' \

0 commit comments

Comments
 (0)