Skip to content

Commit 4a70a99

Browse files
committed
✨ The user can now adjust the chunk size when adding or editing an embedding model
🔨 Change the default chunk sizes from 1200/1500 to 1024/1536
🎨 Improve the coding style of the data_process module in the SDK
1 parent 83a0e5f commit 4a70a99

File tree

9 files changed

+1580
-104
lines changed

9 files changed

+1580
-104
lines changed

backend/apps/data_process_app.py

Lines changed: 19 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -11,26 +11,24 @@
1111
ConvertStateRequest,
1212
TaskRequest,
1313
)
14-
import importlib
15-
import sys
14+
from data_process.tasks import process_and_forward, process_sync
15+
from services.data_process_service import get_data_process_service
1616

1717
logger = logging.getLogger("data_process.app")
1818

19+
# Use shared service instance
20+
service = get_data_process_service()
21+
22+
1923
@asynccontextmanager
2024
async def lifespan(app: APIRouter):
2125
# Startup
2226
try:
23-
svc_mod = sys.modules.get("services.data_process_service") or importlib.import_module(
24-
"services.data_process_service")
25-
svc = svc_mod.get_data_process_service()
26-
await svc.start()
27+
await service.start()
2728
yield
2829
finally:
2930
# Shutdown
30-
svc_mod = sys.modules.get("services.data_process_service") or importlib.import_module(
31-
"services.data_process_service")
32-
svc = svc_mod.get_data_process_service()
33-
await svc.stop()
31+
await service.stop()
3432

3533

3634
router = APIRouter(
@@ -51,9 +49,7 @@ async def create_task(request: TaskRequest, authorization: Optional[str] = Heade
5149

5250
logger.info(
5351
f"Creating task with source_type: {request.source_type}, model_id: {request.embedding_model_id}")
54-
tasks_mod = sys.modules.get(
55-
"data_process.tasks") or importlib.import_module("data_process.tasks")
56-
task_result = tasks_mod.process_and_forward.delay(
52+
task_result = process_and_forward.delay(
5753
source=request.source,
5854
source_type=request.source_type,
5955
chunking_strategy=request.chunking_strategy,
@@ -90,9 +86,7 @@ async def process_sync_endpoint(
9086
"""
9187
try:
9288
# Use the synchronous process task with high priority
93-
tasks_mod = sys.modules.get(
94-
"data_process.tasks") or importlib.import_module("data_process.tasks")
95-
task_result = tasks_mod.process_sync.apply_async(
89+
task_result = process_sync.apply_async(
9690
kwargs={
9791
'source': source,
9892
'source_type': source_type,
@@ -137,10 +131,7 @@ async def create_batch_tasks(request: BatchTaskRequest, authorization: Optional[
137131
Processing happens in the background for each file independently.
138132
"""
139133
try:
140-
svc_mod = sys.modules.get("services.data_process_service") or importlib.import_module(
141-
"services.data_process_service")
142-
svc = svc_mod.get_data_process_service()
143-
task_ids = await svc.create_batch_tasks_impl(authorization=authorization, request=request)
134+
task_ids = await service.create_batch_tasks_impl(authorization=authorization, request=request)
144135
return JSONResponse(status_code=HTTPStatus.CREATED, content={"task_ids": task_ids})
145136
except HTTPException:
146137
raise
@@ -163,16 +154,13 @@ async def load_image(url: str):
163154
"""
164155
try:
165156
# Use the service to load the image
166-
svc_mod = sys.modules.get("services.data_process_service") or importlib.import_module(
167-
"services.data_process_service")
168-
svc = svc_mod.get_data_process_service()
169-
image = await svc.load_image(url)
157+
image = await service.load_image(url)
170158

171159
if image is None:
172160
raise HTTPException(
173161
status_code=HTTPStatus.NOT_FOUND, detail="Failed to load image or image format not supported")
174162

175-
image_data, content_type = await svc.convert_to_base64(image)
163+
image_data, content_type = await service.convert_to_base64(image)
176164
return JSONResponse(status_code=HTTPStatus.OK,
177165
content={"success": True, "base64": image_data, "content_type": content_type})
178166
except HTTPException:
@@ -186,10 +174,7 @@ async def load_image(url: str):
186174
@router.get("")
187175
async def list_tasks():
188176
"""Get a list of all tasks with their basic status information"""
189-
svc_mod = sys.modules.get("services.data_process_service") or importlib.import_module(
190-
"services.data_process_service")
191-
svc = svc_mod.get_data_process_service()
192-
tasks = await svc.get_all_tasks()
177+
tasks = await service.get_all_tasks()
193178

194179
task_responses = []
195180
for task in tasks:
@@ -219,10 +204,7 @@ async def get_index_tasks(index_name: str):
219204
Returns tasks that are being processed or waiting to be processed
220205
"""
221206
try:
222-
svc_mod = sys.modules.get("services.data_process_service") or importlib.import_module(
223-
"services.data_process_service")
224-
svc = svc_mod.get_data_process_service()
225-
return await svc.get_index_tasks(index_name)
207+
return await service.get_index_tasks(index_name)
226208
except Exception as e:
227209
raise HTTPException(
228210
status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail=str(e))
@@ -231,9 +213,7 @@ async def get_index_tasks(index_name: str):
231213
@router.get("/{task_id}/details")
232214
async def get_task_details(task_id: str):
233215
"""Get detailed information about a task, including results"""
234-
utils_mod = sys.modules.get(
235-
"data_process.utils") or importlib.import_module("data_process.utils")
236-
task = await utils_mod.get_task_details(task_id)
216+
task = await service.get_task_details(task_id)
237217
if not task:
238218
raise HTTPException(status_code=HTTPStatus.NOT_FOUND,
239219
detail="Task not found")
@@ -253,8 +233,7 @@ async def filter_important_image(
253233
Returns importance score and confidence level.
254234
"""
255235
try:
256-
svc = get_data_process_service()
257-
result = await svc.filter_important_image(
236+
result = await service.filter_important_image(
258237
image_url=image_url,
259238
positive_prompt=positive_prompt,
260239
negative_prompt=negative_prompt
@@ -291,8 +270,7 @@ async def process_text_file(
291270
file_content = await file.read()
292271
filename = file.filename or "unknown_file"
293272

294-
svc = get_data_process_service()
295-
result = await svc.process_uploaded_text_file(
273+
result = await service.process_uploaded_text_file(
296274
file_content=file_content,
297275
filename=filename,
298276
chunking_strategy=chunking_strategy,
@@ -317,10 +295,7 @@ async def convert_state(request: ConvertStateRequest):
317295
This endpoint converts a process state string to a forward state string.
318296
"""
319297
try:
320-
svc_mod = sys.modules.get("services.data_process_service") or importlib.import_module(
321-
"services.data_process_service")
322-
svc = svc_mod.get_data_process_service()
323-
result = svc.convert_celery_states_to_custom(
298+
result = service.convert_celery_states_to_custom(
324299
process_celery_state=request.process_state or "",
325300
forward_celery_state=request.forward_state or ""
326301
)

docker/init.sql

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,8 @@ CREATE TABLE IF NOT EXISTS "model_record_t" (
163163
"base_url" varchar(500) COLLATE "pg_catalog"."default",
164164
"max_tokens" int4,
165165
"used_token" int4,
166+
"expected_chunk_size" int4,
167+
"maximum_chunk_size" int4,
166168
"display_name" varchar(100) COLLATE "pg_catalog"."default",
167169
"connect_status" varchar(100) COLLATE "pg_catalog"."default",
168170
"create_time" timestamp(0) DEFAULT CURRENT_TIMESTAMP,
@@ -183,6 +185,8 @@ COMMENT ON COLUMN "model_record_t"."api_key" IS 'Model API key, used for authent
183185
COMMENT ON COLUMN "model_record_t"."base_url" IS 'Base URL address, used for requesting remote model services';
184186
COMMENT ON COLUMN "model_record_t"."max_tokens" IS 'Maximum available tokens for the model';
185187
COMMENT ON COLUMN "model_record_t"."used_token" IS 'Number of tokens already used by the model in Q&A';
188+
COMMENT ON COLUMN "model_record_t".expected_chunk_size IS 'Expected chunk size for embedding models, used during document chunking';
189+
COMMENT ON COLUMN "model_record_t".maximum_chunk_size IS 'Maximum chunk size for embedding models, used during document chunking';
186190
COMMENT ON COLUMN "model_record_t"."display_name" IS 'Model name displayed directly in frontend, customized by user';
187191
COMMENT ON COLUMN "model_record_t"."connect_status" IS 'Model connectivity status from last check, optional values: "检测中"、"可用"、"不可用"';
188192
COMMENT ON COLUMN "model_record_t"."create_time" IS 'Creation time, audit field';

test/backend/app/test_data_process_app.py

Lines changed: 106 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import sys
22
import types
3-
import asyncio
43
from typing import Any, Dict, List, Optional, Tuple
4+
from http import HTTPStatus
55

66
import pytest
7-
from fastapi import FastAPI
7+
from fastapi import FastAPI, HTTPException
88
from fastapi.testclient import TestClient
99
from pydantic import BaseModel
1010

@@ -116,6 +116,11 @@ async def get_index_tasks(self, index_name: str):
116116
raise RuntimeError("oops")
117117
return [{"id": "x"}]
118118

119+
async def get_task_details(self, task_id: str):
120+
if task_id == "missing":
121+
return None
122+
return {"id": task_id, "ok": True}
123+
119124
async def filter_important_image(self, image_url: str, positive_prompt: str, negative_prompt: str):
120125
if image_url == "err":
121126
raise RuntimeError("bad")
@@ -212,7 +217,6 @@ def test_process_sync_endpoint_success():
212217
def test_process_sync_endpoint_error(monkeypatch):
213218
# Reconfigure tasks stub to raise when getting result
214219
from backend.apps import data_process_app as app_module
215-
tasks_mod = sys.modules["data_process.tasks"]
216220

217221
class _ErrResult(_DummyResult):
218222
def get(self, timeout=None):
@@ -222,7 +226,7 @@ class _PSyncErr:
222226
def apply_async(self, **kwargs):
223227
return _ErrResult("tid")
224228

225-
setattr(tasks_mod, "process_sync", _PSyncErr())
229+
monkeypatch.setattr(app_module, "process_sync", _PSyncErr(), raising=True)
226230

227231
app = _build_app()
228232
client = TestClient(app)
@@ -233,6 +237,25 @@ def apply_async(self, **kwargs):
233237
assert resp.status_code == 500
234238

235239

240+
def test_process_sync_endpoint_http_exception(monkeypatch):
241+
from backend.apps import data_process_app as app_module
242+
243+
class _PSyncHTTP:
244+
def apply_async(self, **kwargs):
245+
raise HTTPException(
246+
status_code=HTTPStatus.BAD_REQUEST, detail="bad req")
247+
248+
monkeypatch.setattr(app_module, "process_sync", _PSyncHTTP(), raising=True)
249+
250+
app = _build_app()
251+
client = TestClient(app)
252+
resp = client.post(
253+
"/tasks/process",
254+
data={"source": "/tmp/a.txt", "source_type": "local"},
255+
)
256+
assert resp.status_code == HTTPStatus.BAD_REQUEST
257+
258+
236259
def test_batch_tasks_success():
237260
app = _build_app()
238261
client = TestClient(app)
@@ -249,20 +272,37 @@ def test_batch_tasks_success():
249272

250273
def test_batch_tasks_error(monkeypatch):
251274
# Make service raise
252-
svc_mod = sys.modules["services.data_process_service"]
253-
service: _ServiceStub = svc_mod.get_data_process_service()
275+
from backend.apps import data_process_app as app_module
254276

255277
async def err(*args, **kwargs):
256278
raise RuntimeError("x")
257279

258-
service.create_batch_tasks_impl = err # type: ignore
280+
monkeypatch.setattr(app_module.service,
281+
"create_batch_tasks_impl", err, raising=True)
259282

260283
app = _build_app()
261284
client = TestClient(app)
262285
resp = client.post("/tasks/batch", json={"sources": []}, headers={"Authorization": "Bearer t"})
263286
assert resp.status_code == 500
264287

265288

289+
def test_batch_tasks_http_exception(monkeypatch):
290+
from backend.apps import data_process_app as app_module
291+
292+
async def err_http(*args, **kwargs):
293+
raise HTTPException(
294+
status_code=HTTPStatus.NOT_ACCEPTABLE, detail="bad batch")
295+
296+
monkeypatch.setattr(app_module.service,
297+
"create_batch_tasks_impl", err_http, raising=True)
298+
299+
app = _build_app()
300+
client = TestClient(app)
301+
resp = client.post(
302+
"/tasks/batch", json={"sources": []}, headers={"Authorization": "Bearer t"})
303+
assert resp.status_code == HTTPStatus.NOT_ACCEPTABLE
304+
305+
266306
def test_load_image_success_and_not_found():
267307
app = _build_app()
268308
client = TestClient(app)
@@ -274,19 +314,37 @@ def test_load_image_success_and_not_found():
274314

275315

276316
def test_load_image_internal_error(monkeypatch):
277-
svc_mod = sys.modules["services.data_process_service"]
278-
service: _ServiceStub = svc_mod.get_data_process_service()
317+
from backend.apps import data_process_app as app_module
279318

280319
async def err(url: str):
281320
raise RuntimeError("bad")
282321

283-
service.load_image = err # type: ignore
322+
monkeypatch.setattr(app_module.service, "load_image", err, raising=True)
284323
app = _build_app()
285324
client = TestClient(app)
286325
resp = client.get("/tasks/load_image", params={"url": "x"})
287326
assert resp.status_code == 500
288327

289328

329+
def test_filter_important_image_http_exception(monkeypatch):
330+
from backend.apps import data_process_app as app_module
331+
332+
async def err_http(*args, **kwargs):
333+
raise HTTPException(
334+
status_code=HTTPStatus.BAD_REQUEST, detail="bad image")
335+
336+
monkeypatch.setattr(app_module.service,
337+
"filter_important_image", err_http, raising=True)
338+
339+
app = _build_app()
340+
client = TestClient(app)
341+
resp = client.post(
342+
"/tasks/filter_important_image",
343+
data={"image_url": "u"},
344+
)
345+
assert resp.status_code == HTTPStatus.BAD_REQUEST
346+
347+
290348
def test_list_tasks():
291349
app = _build_app()
292350
client = TestClient(app)
@@ -342,19 +400,54 @@ def test_process_text_file_success_and_error(tmp_path):
342400
assert bad.status_code == 500
343401

344402

403+
def test_process_text_file_http_exception(monkeypatch):
404+
from backend.apps import data_process_app as app_module
405+
406+
async def err_http(*args, **kwargs):
407+
raise HTTPException(
408+
status_code=HTTPStatus.UNPROCESSABLE_ENTITY, detail="bad file")
409+
410+
monkeypatch.setattr(app_module.service,
411+
"process_uploaded_text_file", err_http, raising=True)
412+
413+
app = _build_app()
414+
client = TestClient(app)
415+
files = {"file": ("x.txt", b"hello", "text/plain")}
416+
resp = client.post("/tasks/process_text_file", files=files,
417+
data={"chunking_strategy": "basic"})
418+
assert resp.status_code == HTTPStatus.UNPROCESSABLE_ENTITY
419+
420+
345421
def test_convert_state_success_and_error(monkeypatch):
346422
app = _build_app()
347423
client = TestClient(app)
348424
ok = client.post("/tasks/convert_state", json={"process_state": "SUCCESS", "forward_state": "SUCCESS"})
349425
assert ok.status_code == 200 and ok.json()["state"] == "COMPLETED"
350426

351427
# Make service raise
352-
svc_mod = sys.modules["services.data_process_service"]
353-
service: _ServiceStub = svc_mod.get_data_process_service()
428+
from backend.apps import data_process_app as app_module
354429
def raise_convert(*args, **kwargs):
355430
raise RuntimeError("x")
356-
service.convert_celery_states_to_custom = raise_convert # type: ignore
431+
monkeypatch.setattr(
432+
app_module.service, "convert_celery_states_to_custom", raise_convert, raising=True)
357433
err = client.post("/tasks/convert_state", json={"process_state": "PENDING", "forward_state": ""})
358434
assert err.status_code == 500
359435

360436

437+
def test_convert_state_http_exception(monkeypatch):
438+
app = _build_app()
439+
client = TestClient(app)
440+
441+
from backend.apps import data_process_app as app_module
442+
443+
def raise_convert_http(*args, **kwargs):
444+
raise HTTPException(
445+
status_code=HTTPStatus.NOT_ACCEPTABLE, detail="bad convert")
446+
447+
monkeypatch.setattr(
448+
app_module.service, "convert_celery_states_to_custom", raise_convert_http, raising=True
449+
)
450+
451+
resp = client.post("/tasks/convert_state",
452+
json={"process_state": "PENDING", "forward_state": ""})
453+
assert resp.status_code == HTTPStatus.NOT_ACCEPTABLE

0 commit comments

Comments
 (0)