Skip to content

Commit a0d603c

Browse files
committed
feat: support Elasticsearch datasource #108
1 parent f86b2f6 commit a0d603c

File tree

5 files changed

+72
-48
lines changed

5 files changed

+72
-48
lines changed

backend/apps/chat/task/llm.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -881,7 +881,7 @@ def save_sql_data(self, data_obj: Dict[str, Any]):
881881
def finish(self):
882882
return finish_record(session=self.session, record_id=self.record.id)
883883

884-
def execute_sql(self, sql: str, tables):
884+
def execute_sql(self, sql: str):
885885
"""Execute SQL query
886886
887887
Args:
@@ -893,7 +893,7 @@ def execute_sql(self, sql: str, tables):
893893
"""
894894
SQLBotLogUtil.info(f"Executing SQL on ds_id {self.ds.id}: {sql}")
895895
try:
896-
return exec_sql(ds=self.ds, sql=sql, origin_column=False, table_name=tables)
896+
return exec_sql(ds=self.ds, sql=sql, origin_column=False)
897897
except Exception as e:
898898
if isinstance(e, ParseSQLResultError):
899899
raise e
@@ -1022,7 +1022,6 @@ def run_task(self, in_chat: bool = True):
10221022
sql = self.check_save_sql(res=full_sql_text)
10231023
else:
10241024
sql = self.check_save_sql(res=full_sql_text)
1025-
tables = []
10261025

10271026
SQLBotLogUtil.info(sql)
10281027
format_sql = sqlparse.format(sql, reindent=True)
@@ -1040,7 +1039,7 @@ def run_task(self, in_chat: bool = True):
10401039
subsql)
10411040
real_execute_sql = assistant_dynamic_sql
10421041

1043-
result = self.execute_sql(sql=real_execute_sql, tables=tables)
1042+
result = self.execute_sql(sql=real_execute_sql)
10441043
self.save_sql_data(data_obj=result)
10451044
if in_chat:
10461045
yield 'data:' + orjson.dumps({'content': 'execute-success', 'type': 'sql-data'}).decode() + '\n\n'

backend/apps/datasource/crud/datasource.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ def preview(session: SessionDep, current_user: CurrentUser, id: int, data: Table
297297
sql = f"""SELECT "{'", "'.join(fields)}" FROM "{data.table.table_name}"
298298
{where}
299299
LIMIT 100"""
300-
return exec_sql(ds, sql, True, [data.table.table_name])
300+
return exec_sql(ds, sql, True)
301301

302302

303303
def fieldEnum(session: SessionDep, id: int):
@@ -313,7 +313,7 @@ def fieldEnum(session: SessionDep, id: int):
313313

314314
db = DB.get_db(ds.type)
315315
sql = f"""SELECT DISTINCT {db.prefix}{field.field_name}{db.suffix} FROM {db.prefix}{table.table_name}{db.suffix}"""
316-
res = exec_sql(ds, sql, True, [table.table_name])
316+
res = exec_sql(ds, sql, True)
317317
return [item.get(res.get('fields')[0]) for item in res.get('data')]
318318

319319

backend/apps/db/db.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from common.core.deps import Trans
2525
from common.utils.utils import SQLBotLogUtil
2626
from fastapi import HTTPException
27-
from apps.db.es_engine import get_es_connect, get_es_index, get_es_fields, get_es_data
27+
from apps.db.es_engine import get_es_connect, get_es_index, get_es_fields, get_es_data_by_http
2828

2929

3030
def get_uri(ds: CoreDatasource) -> str:
@@ -341,7 +341,7 @@ def get_fields(ds: CoreDatasource, table_name: str = None):
341341
return res_list
342342

343343

344-
def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column=False, table_name=None):
344+
def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column=False):
345345
while sql.endswith(';'):
346346
sql = sql[:-1]
347347

@@ -421,17 +421,16 @@ def exec_sql(ds: CoreDatasource | AssistantOutDsSchema, sql: str, origin_column=
421421
raise ParseSQLResultError(str(ex))
422422
elif ds.type == 'es':
423423
try:
424-
if table_name and table_name[0]:
425-
res, columns = get_es_data(conf, sql, table_name[0])
426-
columns = [field[0] for field in columns] if origin_column else [field[0].lower() for
427-
field in
428-
columns]
429-
result_list = [
430-
{str(columns[i]): float(value) if isinstance(value, Decimal) else value for i, value in
431-
enumerate(tuple_item)}
432-
for tuple_item in res
433-
]
434-
return {"fields": columns, "data": result_list,
435-
"sql": bytes.decode(base64.b64encode(bytes(sql, 'utf-8')))}
424+
res, columns = get_es_data_by_http(conf, sql)
425+
columns = [field.get('name') for field in columns] if origin_column else [field.get('name').lower() for
426+
field in
427+
columns]
428+
result_list = [
429+
{str(columns[i]): float(value) if isinstance(value, Decimal) else value for i, value in
430+
enumerate(tuple(tuple_item))}
431+
for tuple_item in res
432+
]
433+
return {"fields": columns, "data": result_list,
434+
"sql": bytes.decode(base64.b64encode(bytes(sql, 'utf-8')))}
436435
except Exception as ex:
437436
raise Exception(str(ex))

backend/apps/db/es_engine.py

Lines changed: 51 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@
22
# Date: 2025/9/9
33

44
import json
5+
from base64 import b64encode
56

67
import requests
78
from elasticsearch import Elasticsearch
8-
from fastapi import HTTPException
99

1010
from apps.datasource.models.datasource import DatasourceConf
1111

@@ -60,29 +60,55 @@ def get_es_fields(conf: DatasourceConf, table_name: str):
6060
return res
6161

6262

63-
def get_es_data(conf: DatasourceConf, sql: str, table_name: str):
64-
r = requests.post(f"{conf.host}/_sql/translate", json={"query": sql})
65-
if r.json().get('error'):
66-
print(json.dumps(r.json()))
63+
# def get_es_data(conf: DatasourceConf, sql: str, table_name: str):
64+
# r = requests.post(f"{conf.host}/_sql/translate", json={"query": sql})
65+
# if r.json().get('error'):
66+
# print(json.dumps(r.json()))
67+
#
68+
# es_client = get_es_connect(conf)
69+
# response = es_client.search(
70+
# index=table_name,
71+
# body=json.dumps(r.json())
72+
# )
73+
#
74+
# # print(response)
75+
# fields = get_es_fields(conf, table_name)
76+
# res = []
77+
# for hit in response.get('hits').get('hits'):
78+
# item = []
79+
# if 'fields' in hit:
80+
# result = hit.get('fields') # {'title': ['Python'], 'age': [30]}
81+
# for field in fields:
82+
# v = result.get(field[0])
83+
# item.append(v[0]) if v else item.append(None)
84+
# res.append(tuple(item))
85+
# # print(hit['fields']['title'][0])
86+
# # elif '_source' in hit:
87+
# # print(hit.get('_source'))
88+
# return res, fields
6789

68-
es_client = get_es_connect(conf)
69-
response = es_client.search(
70-
index=table_name,
71-
body=json.dumps(r.json())
72-
)
7390

74-
# print(response)
75-
fields = get_es_fields(conf, table_name)
76-
res = []
77-
for hit in response.get('hits').get('hits'):
78-
item = []
79-
if 'fields' in hit:
80-
result = hit.get('fields') # {'title': ['Python'], 'age': [30]}
81-
for field in fields:
82-
v = result.get(field[0])
83-
item.append(v[0]) if v else item.append(None)
84-
res.append(tuple(item))
85-
# print(hit['fields']['title'][0])
86-
# elif '_source' in hit:
87-
# print(hit.get('_source'))
88-
return res, fields
91+
def get_es_data_by_http(conf: DatasourceConf, sql: str):
92+
url = conf.host
93+
while url.endswith('/'):
94+
url = sql[:-1]
95+
96+
host = f'{url}/_sql?format=json'
97+
username = f"{conf.username}"
98+
password = f"{conf.password}"
99+
100+
credentials = f"{username}:{password}"
101+
encoded_credentials = b64encode(credentials.encode()).decode()
102+
103+
headers = {
104+
"Content-Type": "application/json",
105+
"Authorization": f"Basic {encoded_credentials}"
106+
}
107+
108+
response = requests.post(host, data=json.dumps({"query": sql}), headers=headers)
109+
110+
# print(response.json())
111+
res = response.json()
112+
fields = res.get('columns')
113+
result = res.get('rows')
114+
return result, fields

frontend/src/views/ds/js/ds-type.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import ck from '@/assets/datasource/icon_ck.png'
77
import dm from '@/assets/datasource/icon_dm.png'
88
import doris from '@/assets/datasource/icon_doris.png'
99
import redshift from '@/assets/datasource/icon_redshift.png'
10-
// import es from '@/assets/datasource/icon_es.png'
10+
import es from '@/assets/datasource/icon_es.png'
1111
import { i18n } from '@/i18n'
1212

1313
const t = i18n.global.t
@@ -21,7 +21,7 @@ export const dsType = [
2121
{ label: '达梦', value: 'dm' },
2222
{ label: 'Apache Doris', value: 'doris' },
2323
{ label: 'AWS Redshift', value: 'redshift' },
24-
// { label: 'Elasticsearch', value: 'es' },
24+
{ label: 'Elasticsearch', value: 'es' },
2525
]
2626

2727
export const dsTypeWithImg = [
@@ -34,7 +34,7 @@ export const dsTypeWithImg = [
3434
{ name: '达梦', type: 'dm', img: dm },
3535
{ name: 'Apache Doris', type: 'doris', img: doris },
3636
{ name: 'AWS Redshift', type: 'redshift', img: redshift },
37-
// { name: 'Elasticsearch', type: 'es', img: es },
37+
{ name: 'Elasticsearch', type: 'es', img: es },
3838
]
3939

4040
export const haveSchema = ['sqlServer', 'pg', 'oracle', 'dm', 'redshift']

0 commit comments

Comments
 (0)