Skip to content

Commit 26bef5b

Browse files
authored
feat: Add parameters for Kubeflow pipeline engine (WIP) (docling-project#107)
Signed-off-by: Michele Dolfi <dol@zurich.ibm.com>
1 parent 40bb21d commit 26bef5b

21 files changed

+3782
-3133
lines changed

docling_serve/app.py

Lines changed: 37 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@
2828

2929
from docling.datamodel.base_models import DocumentStream
3030

31+
from docling_serve.datamodel.callback import (
32+
ProgressCallbackRequest,
33+
ProgressCallbackResponse,
34+
)
3135
from docling_serve.datamodel.convert import ConvertDocumentsOptions
3236
from docling_serve.datamodel.requests import (
3337
ConvertDocumentFileSourcesRequest,
@@ -45,11 +49,12 @@
4549
get_converter,
4650
get_pdf_pipeline_opts,
4751
)
48-
from docling_serve.engines import get_orchestrator
49-
from docling_serve.engines.async_local.orchestrator import (
50-
AsyncLocalOrchestrator,
51-
TaskNotFoundError,
52+
from docling_serve.engines.async_orchestrator import (
53+
BaseAsyncOrchestrator,
54+
ProgressInvalid,
5255
)
56+
from docling_serve.engines.async_orchestrator_factory import get_async_orchestrator
57+
from docling_serve.engines.base_orchestrator import TaskNotFoundError
5358
from docling_serve.helper_functions import FormDepends
5459
from docling_serve.response_preparation import process_results
5560
from docling_serve.settings import docling_serve_settings
@@ -94,7 +99,7 @@ async def lifespan(app: FastAPI):
9499
pdf_format_option = get_pdf_pipeline_opts(ConvertDocumentsOptions())
95100
get_converter(pdf_format_option)
96101

97-
orchestrator = get_orchestrator()
102+
orchestrator = get_async_orchestrator()
98103

99104
# Start the background queue processor
100105
queue_task = asyncio.create_task(orchestrator.process_queue())
@@ -308,7 +313,7 @@ async def process_file(
308313
response_model=TaskStatusResponse,
309314
)
310315
async def process_url_async(
311-
orchestrator: Annotated[AsyncLocalOrchestrator, Depends(get_orchestrator)],
316+
orchestrator: Annotated[BaseAsyncOrchestrator, Depends(get_async_orchestrator)],
312317
conversion_request: ConvertDocumentsRequest,
313318
):
314319
task = await orchestrator.enqueue(request=conversion_request)
@@ -319,6 +324,7 @@ async def process_url_async(
319324
task_id=task.task_id,
320325
task_status=task.task_status,
321326
task_position=task_queue_position,
327+
task_meta=task.processing_meta,
322328
)
323329

324330
# Task status poll
@@ -327,7 +333,7 @@ async def process_url_async(
327333
response_model=TaskStatusResponse,
328334
)
329335
async def task_status_poll(
330-
orchestrator: Annotated[AsyncLocalOrchestrator, Depends(get_orchestrator)],
336+
orchestrator: Annotated[BaseAsyncOrchestrator, Depends(get_async_orchestrator)],
331337
task_id: str,
332338
wait: Annotated[
333339
float, Query(help="Number of seconds to wait for a completed status.")
@@ -342,6 +348,7 @@ async def task_status_poll(
342348
task_id=task.task_id,
343349
task_status=task.task_status,
344350
task_position=task_queue_position,
351+
task_meta=task.processing_meta,
345352
)
346353

347354
# Task status websocket
@@ -350,7 +357,7 @@ async def task_status_poll(
350357
)
351358
async def task_status_ws(
352359
websocket: WebSocket,
353-
orchestrator: Annotated[AsyncLocalOrchestrator, Depends(get_orchestrator)],
360+
orchestrator: Annotated[BaseAsyncOrchestrator, Depends(get_async_orchestrator)],
354361
task_id: str,
355362
):
356363
await websocket.accept()
@@ -375,6 +382,7 @@ async def task_status_ws(
375382
task_id=task.task_id,
376383
task_status=task.task_status,
377384
task_position=task_queue_position,
385+
task_meta=task.processing_meta,
378386
)
379387
await websocket.send_text(
380388
WebsocketMessage(
@@ -389,6 +397,7 @@ async def task_status_ws(
389397
task_id=task.task_id,
390398
task_status=task.task_status,
391399
task_position=task_queue_position,
400+
task_meta=task.processing_meta,
392401
)
393402
await websocket.send_text(
394403
WebsocketMessage(
@@ -416,7 +425,7 @@ async def task_status_ws(
416425
},
417426
)
418427
async def task_result(
419-
orchestrator: Annotated[AsyncLocalOrchestrator, Depends(get_orchestrator)],
428+
orchestrator: Annotated[BaseAsyncOrchestrator, Depends(get_async_orchestrator)],
420429
task_id: str,
421430
):
422431
result = await orchestrator.task_result(task_id=task_id)
@@ -427,4 +436,23 @@ async def task_result(
427436
)
428437
return result
429438

439+
# Update task progress
@app.post(
    "/v1alpha/callback/task/progress",
    response_model=ProgressCallbackResponse,
)
async def callback_task_progress(
    orchestrator: Annotated[BaseAsyncOrchestrator, Depends(get_async_orchestrator)],
    request: ProgressCallbackRequest,
):
    """Receive a task-progress callback and forward it to the orchestrator.

    Returns an "ack" response on success. Orchestrator errors are mapped to
    HTTP errors: unknown task -> 404, invalid progress payload -> 400.
    """
    try:
        await orchestrator.receive_task_progress(request=request)
        return ProgressCallbackResponse(status="ack")
    except TaskNotFoundError as err:
        # Chain the original exception so tracebacks keep the root cause
        # (avoids ruff B904 "raise without from inside except").
        raise HTTPException(status_code=404, detail="Task not found.") from err
    except ProgressInvalid as err:
        raise HTTPException(
            status_code=400, detail=f"Invalid progress payload: {err}"
        ) from err
457+
430458
return app
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import enum
2+
from typing import Annotated, Literal
3+
4+
from pydantic import BaseModel, Field
5+
6+
7+
class ProgressKind(str, enum.Enum):
    """Discriminator values for the progress payload union."""

    SET_NUM_DOCS = "set_num_docs"
    UPDATE_PROCESSED = "update_processed"
10+
11+
12+
class BaseProgress(BaseModel):
    """Common base for progress payloads; `kind` selects the concrete subtype."""

    kind: ProgressKind
14+
15+
16+
class ProgressSetNumDocs(BaseProgress):
    """Announces the total number of documents a task will process."""

    kind: Literal[ProgressKind.SET_NUM_DOCS] = ProgressKind.SET_NUM_DOCS

    # Total document count for the task.
    num_docs: int
20+
21+
22+
class SucceededDocsItem(BaseModel):
    """A successfully processed document, identified by its source."""

    source: str
24+
25+
26+
class FailedDocsItem(BaseModel):
    """A document that failed processing, with the associated error message."""

    source: str
    error: str
29+
30+
31+
class ProgressUpdateProcessed(BaseProgress):
    """Reports a set of processed documents: counts plus per-document outcomes."""

    kind: Literal[ProgressKind.UPDATE_PROCESSED] = ProgressKind.UPDATE_PROCESSED

    # NOTE(review): counts look per-callback-batch rather than cumulative
    # (convert_batch reports only its own batch) — confirm how the
    # orchestrator aggregates them.
    num_processed: int
    num_succeeded: int
    num_failed: int

    docs_succeeded: list[SucceededDocsItem]
    docs_failed: list[FailedDocsItem]
40+
41+
42+
class ProgressCallbackRequest(BaseModel):
    """Request body posted to the /v1alpha/callback/task/progress endpoint."""

    task_id: str
    # Discriminated union: pydantic dispatches on the `kind` field.
    progress: Annotated[
        ProgressSetNumDocs | ProgressUpdateProcessed, Field(discriminator="kind")
    ]
47+
48+
49+
class ProgressCallbackResponse(BaseModel):
    """Acknowledgement returned for an accepted progress callback."""

    status: Literal["ack"] = "ack"

docling_serve/datamodel/engines.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ class TaskStatus(str, enum.Enum):
1010

1111
class AsyncEngine(str, enum.Enum):
    """Engines available for asynchronous task orchestration."""

    LOCAL = "local"
    # Kubeflow Pipelines engine (see docling_serve.engines.async_kfp).
    KFP = "kfp"

docling_serve/datamodel/kfp.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from pydantic import AnyUrl, BaseModel
2+
3+
4+
class CallbackSpec(BaseModel):
    """Destination for progress callbacks issued from KFP pipeline steps."""

    url: AnyUrl
    # Extra HTTP headers sent with the callback request.
    headers: dict[str, str] = {}
    # NOTE(review): semantics not visible here — presumably PEM CA certificate
    # content for TLS verification (empty = default trust store); confirm
    # whether it is the certificate body or a file path.
    ca_cert: str = ""

docling_serve/datamodel/responses.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from docling.utils.profiling import ProfilingItem
88
from docling_core.types.doc import DoclingDocument
99

10+
from docling_serve.datamodel.task_meta import TaskProcessingMeta
11+
1012

1113
# Status
1214
class HealthCheckResponse(BaseModel):
@@ -38,6 +40,7 @@ class TaskStatusResponse(BaseModel):
3840
task_id: str
3941
task_status: str
4042
task_position: Optional[int] = None
43+
task_meta: Optional[TaskProcessingMeta] = None
4144

4245

4346
class MessageKind(str, enum.Enum):

docling_serve/datamodel/task.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,15 @@
55
from docling_serve.datamodel.engines import TaskStatus
66
from docling_serve.datamodel.requests import ConvertDocumentsRequest
77
from docling_serve.datamodel.responses import ConvertDocumentResponse
8+
from docling_serve.datamodel.task_meta import TaskProcessingMeta
89

910

1011
class Task(BaseModel):
1112
task_id: str
1213
task_status: TaskStatus = TaskStatus.PENDING
1314
request: Optional[ConvertDocumentsRequest]
1415
result: Optional[ConvertDocumentResponse] = None
16+
processing_meta: Optional[TaskProcessingMeta] = None
1517

1618
def is_completed(self) -> bool:
1719
if self.task_status in [TaskStatus.SUCCESS, TaskStatus.FAILURE]:
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from pydantic import BaseModel
2+
3+
4+
class TaskProcessingMeta(BaseModel):
5+
num_docs: int
6+
num_processed: int = 0
7+
num_succeeded: int = 0
8+
num_failed: int = 0

docling_serve/engines/__init__.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +0,0 @@
1-
from functools import lru_cache
2-
3-
from docling_serve.engines.async_local.orchestrator import AsyncLocalOrchestrator
4-
5-
6-
@lru_cache
7-
def get_orchestrator() -> AsyncLocalOrchestrator:
8-
return AsyncLocalOrchestrator()

docling_serve/engines/async_kfp/__init__.py

Whitespace-only changes.
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
# ruff: noqa: E402, UP006, UP035
2+
3+
from typing import Any, Dict, List
4+
5+
from kfp import dsl
6+
7+
# Container base image shared by all KFP components below.
PYTHON_BASE_IMAGE = "python:3.12"
8+
9+
10+
@dsl.component(
    base_image=PYTHON_BASE_IMAGE,
    packages_to_install=[
        "pydantic",
        "docling-serve @ git+https://github.com/docling-project/docling-serve@feat-kfp-engine",
    ],
    pip_index_urls=["https://download.pytorch.org/whl/cpu", "https://pypi.org/simple"],
)
def generate_chunks(
    run_name: str,
    request: Dict[str, Any],
    batch_size: int,
    callbacks: List[Dict[str, Any]],
) -> List[List[Dict[str, Any]]]:
    """Split the request's http_sources into batches of at most batch_size.

    Before returning the batches, notifies all configured callbacks of the
    total document count via a ProgressSetNumDocs payload.
    """
    # Imports live inside the component body: KFP serializes the function
    # source and runs it in a fresh container.
    from pydantic import TypeAdapter

    from docling_serve.datamodel.callback import (
        ProgressCallbackRequest,
        ProgressSetNumDocs,
    )
    from docling_serve.datamodel.kfp import CallbackSpec
    from docling_serve.engines.async_kfp.notify import notify_callbacks

    CallbacksListType = TypeAdapter(list[CallbackSpec])

    sources = request["http_sources"]
    splits = [sources[i : i + batch_size] for i in range(0, len(sources), batch_size)]

    # Total across all chunks (equals len(sources)); announced to callbacks.
    total = sum(len(chunk) for chunk in splits)
    payload = ProgressCallbackRequest(
        task_id=run_name, progress=ProgressSetNumDocs(num_docs=total)
    )
    notify_callbacks(
        payload=payload,
        callbacks=CallbacksListType.validate_python(callbacks),
    )

    return splits
48+
49+
50+
@dsl.component(
    base_image=PYTHON_BASE_IMAGE,
    packages_to_install=[
        "pydantic",
        "docling-serve @ git+https://github.com/docling-project/docling-serve@feat-kfp-engine",
    ],
    pip_index_urls=["https://download.pytorch.org/whl/cpu", "https://pypi.org/simple"],
)
def convert_batch(
    run_name: str,
    data_splits: List[Dict[str, Any]],
    request: Dict[str, Any],
    callbacks: List[Dict[str, Any]],
    output_path: dsl.OutputPath("Directory"),  # type: ignore
):
    """Process one batch of sources, writing one file per source to output_path.

    After the batch, notifies all configured callbacks with a
    ProgressUpdateProcessed payload covering this batch only.

    NOTE(review): the loop currently just dumps each source's JSON to a file —
    no actual docling conversion happens yet (commit is marked WIP), and
    docs_failed is never populated.
    """
    # In-component imports: KFP serializes the function source and runs it
    # in a fresh container.
    from pathlib import Path

    from pydantic import AnyUrl, TypeAdapter

    from docling_serve.datamodel.callback import (
        FailedDocsItem,
        ProgressCallbackRequest,
        ProgressUpdateProcessed,
        SucceededDocsItem,
    )
    from docling_serve.datamodel.convert import ConvertDocumentsOptions
    from docling_serve.datamodel.kfp import CallbackSpec
    from docling_serve.datamodel.requests import HttpSource
    from docling_serve.engines.async_kfp.notify import notify_callbacks

    CallbacksListType = TypeAdapter(list[CallbackSpec])

    # Validate the conversion options up front, even though they are not
    # used by the (WIP) body below.
    convert_options = ConvertDocumentsOptions.model_validate(request["options"])
    print(convert_options)

    output_dir = Path(output_path)
    output_dir.mkdir(exist_ok=True, parents=True)
    docs_succeeded: list[SucceededDocsItem] = []
    docs_failed: list[FailedDocsItem] = []
    for source_dict in data_splits:
        source = HttpSource.model_validate(source_dict)
        # Output filename = last path segment of the source URL.
        filename = Path(str(AnyUrl(source.url).path)).name
        output_filename = output_dir / filename
        print(f"Writing {output_filename}")
        with output_filename.open("w") as f:
            f.write(source.model_dump_json())
        docs_succeeded.append(SucceededDocsItem(source=source.url))

    # Report this batch's outcome to all registered callbacks.
    payload = ProgressCallbackRequest(
        task_id=run_name,
        progress=ProgressUpdateProcessed(
            num_failed=len(docs_failed),
            num_processed=len(docs_succeeded) + len(docs_failed),
            num_succeeded=len(docs_succeeded),
            docs_succeeded=docs_succeeded,
            docs_failed=docs_failed,
        ),
    )

    print(payload)
    notify_callbacks(
        payload=payload,
        callbacks=CallbacksListType.validate_python(callbacks),
    )
114+
115+
116+
@dsl.pipeline()
def process(
    batch_size: int,
    request: Dict[str, Any],
    callbacks: List[Dict[str, Any]] = [],
    run_name: str = "",
):
    """KFP pipeline: split the request into batches and convert them in parallel.

    Args:
        batch_size: maximum number of sources per convert_batch invocation.
        request: serialized ConvertDocumentsRequest-like dict (needs
            "http_sources" and "options" keys).
        callbacks: serialized CallbackSpec list for progress notifications.
        run_name: task identifier propagated into callback payloads.
    """
    chunks_task = generate_chunks(
        run_name=run_name,
        request=request,
        batch_size=batch_size,
        callbacks=callbacks,
    )
    # Disable caching: chunking has a side effect (it notifies callbacks of
    # the total document count), so it must re-run on every submission.
    chunks_task.set_caching_options(False)

    # Fan out over the batches with bounded parallelism.
    with dsl.ParallelFor(chunks_task.output, parallelism=4) as data_splits:
        convert_batch(
            run_name=run_name,
            data_splits=data_splits,
            request=request,
            callbacks=callbacks,
        )

0 commit comments

Comments
 (0)