Skip to content

Commit fcbf9e3

Browse files
authored
feat(query-handler): expose query handler API in Python SDK (#1011)
1 parent 85feb1a commit fcbf9e3

File tree

9 files changed

+131
-18
lines changed

9 files changed

+131
-18
lines changed

python/cocoindex/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from .index import VectorSimilarityMetric, VectorIndexDef, IndexOptions
2525
from .setting import DatabaseConnectionSpec, Settings, ServerSettings
2626
from .setting import get_app_namespace
27+
from .query_handler import QueryHandlerResultFields, QueryInfo, QueryOutput
2728
from .typing import (
2829
Int64,
2930
Float32,
@@ -95,4 +96,8 @@
9596
"Range",
9697
"Vector",
9798
"Json",
99+
# Query handler
100+
"QueryHandlerResultFields",
101+
"QueryInfo",
102+
"QueryOutput",
98103
]

python/cocoindex/convert.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,8 @@ def dump_engine_object(v: Any) -> Any:
616616
return s
617617
elif isinstance(v, (list, tuple)):
618618
return [dump_engine_object(item) for item in v]
619+
elif isinstance(v, np.ndarray):
620+
return v.tolist()
619621
elif isinstance(v, dict):
620622
return {k: dump_engine_object(v) for k, v in v.items()}
621623
return v

python/cocoindex/flow.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,10 @@
3838
make_engine_value_encoder,
3939
)
4040
from .op import FunctionSpec
41-
from .runtime import execution_context
41+
from .runtime import execution_context, to_async_call
4242
from .setup import SetupChangeBundle
4343
from .typing import analyze_type_info, encode_enriched_type
44+
from .query_handler import QueryHandlerInfo, QueryHandlerResultFields
4445
from .validation import (
4546
validate_flow_name,
4647
validate_full_flow_name,
@@ -697,13 +698,15 @@ class Flow:
697698
_engine_flow_creator: Callable[[], _engine.Flow]
698699

699700
_lazy_flow_lock: Lock
701+
_lazy_query_handler_args: list[tuple[Any, ...]]
700702
_lazy_engine_flow: _engine.Flow | None = None
701703

702704
def __init__(self, name: str, engine_flow_creator: Callable[[], _engine.Flow]):
703705
validate_flow_name(name)
704706
self._name = name
705707
self._engine_flow_creator = engine_flow_creator
706708
self._lazy_flow_lock = Lock()
709+
self._lazy_query_handler_args = []
707710

708711
def _render_spec(self, verbose: bool = False) -> Tree:
709712
"""
@@ -809,6 +812,9 @@ def _internal_flow(self) -> _engine.Flow:
809812

810813
engine_flow = self._engine_flow_creator()
811814
self._lazy_engine_flow = engine_flow
815+
for args in self._lazy_query_handler_args:
816+
engine_flow.add_query_handler(*args)
817+
self._lazy_query_handler_args = []
812818

813819
return engine_flow
814820

@@ -855,6 +861,43 @@ def close(self) -> None:
855861
with _flows_lock:
856862
del _flows[self.name]
857863

864+
def add_query_handler(
865+
self,
866+
name: str,
867+
handler: Callable[[str], Any],
868+
/,
869+
*,
870+
result_fields: QueryHandlerResultFields | None = None,
871+
) -> None:
872+
async_handler = to_async_call(handler)
873+
874+
async def _handler(query: str) -> dict[str, Any]:
875+
handler_result = await async_handler(query)
876+
return {
877+
"results": dump_engine_object(handler_result.results),
878+
"query_info": dump_engine_object(handler_result.query_info),
879+
}
880+
881+
handler_info = dump_engine_object(QueryHandlerInfo(result_fields=result_fields))
882+
with self._lazy_flow_lock:
883+
if self._lazy_engine_flow is not None:
884+
self._lazy_engine_flow.add_query_handler(name, _handler, handler_info)
885+
else:
886+
self._lazy_query_handler_args.append((name, _handler, handler_info))
887+
888+
def query_handler(
889+
self,
890+
name: str | None = None,
891+
result_fields: QueryHandlerResultFields | None = None,
892+
) -> Callable[[Callable[[str], Any]], Callable[[str], Any]]:
893+
def _inner(handler: Callable[[str], Any]) -> Callable[[str], Any]:
894+
self.add_query_handler(
895+
name or handler.__name__, handler, result_fields=result_fields
896+
)
897+
return handler
898+
899+
return _inner
900+
858901

859902
def _create_lazy_flow(
860903
name: str | None, fl_def: Callable[[FlowBuilder, DataScope], None]

python/cocoindex/op.py

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
Facilities for defining cocoindex operations.
33
"""
44

5-
import asyncio
65
import dataclasses
76
import inspect
87
from enum import Enum
@@ -32,6 +31,7 @@
3231
AnalyzedAnyType,
3332
AnalyzedDictType,
3433
)
34+
from .runtime import to_async_call
3535

3636

3737
class OpCategory(Enum):
@@ -150,12 +150,6 @@ class OpArgs:
150150
arg_relationship: tuple[ArgRelationship, str] | None = None
151151

152152

153-
def _to_async_call(call: Callable[..., Any]) -> Callable[..., Awaitable[Any]]:
154-
if inspect.iscoroutinefunction(call):
155-
return call
156-
return lambda *args, **kwargs: asyncio.to_thread(lambda: call(*args, **kwargs))
157-
158-
159153
@dataclasses.dataclass
160154
class _ArgInfo:
161155
decoder: Callable[[Any], Any]
@@ -319,8 +313,8 @@ async def prepare(self) -> None:
319313
"""
320314
prepare_method = getattr(self._executor, "prepare", None)
321315
if prepare_method is not None:
322-
await _to_async_call(prepare_method)()
323-
self._acall = _to_async_call(self._executor.__call__)
316+
await to_async_call(prepare_method)()
317+
self._acall = to_async_call(self._executor.__call__)
324318

325319
async def __call__(self, *args: Any, **kwargs: Any) -> Any:
326320
decoded_args = []
@@ -461,12 +455,12 @@ def __init__(self, spec_cls: type, connector_cls: type):
461455
self._get_persistent_key_fn = _get_required_method(
462456
connector_cls, "get_persistent_key"
463457
)
464-
self._apply_setup_change_async_fn = _to_async_call(
458+
self._apply_setup_change_async_fn = to_async_call(
465459
_get_required_method(connector_cls, "apply_setup_change")
466460
)
467461

468462
mutate_fn = _get_required_method(connector_cls, "mutate")
469-
self._mutate_async_fn = _to_async_call(mutate_fn)
463+
self._mutate_async_fn = to_async_call(mutate_fn)
470464

471465
# Store the type annotation for later use
472466
self._mutatation_type = self._analyze_mutate_mutation_type(

python/cocoindex/query.py

Whitespace-only changes.

python/cocoindex/query_handler.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import dataclasses
2+
import numpy as np
3+
from numpy import typing as npt
4+
from typing import Generic, TypeVar
5+
6+
7+
@dataclasses.dataclass
8+
class QueryHandlerResultFields:
9+
"""
10+
Specify field names in the returned query handler.
11+
"""
12+
13+
embedding: list[str] = dataclasses.field(default_factory=list)
14+
score: str | None = None
15+
16+
17+
@dataclasses.dataclass
18+
class QueryHandlerInfo:
19+
"""
20+
Info to configure a query handler.
21+
"""
22+
23+
result_fields: QueryHandlerResultFields | None = None
24+
25+
26+
@dataclasses.dataclass
27+
class QueryInfo:
28+
"""
29+
Info about the query.
30+
"""
31+
32+
embedding: list[float] | npt.NDArray[np.float32] | None = None
33+
34+
35+
R = TypeVar("R")
36+
37+
38+
@dataclasses.dataclass
39+
class QueryOutput(Generic[R]):
40+
"""
41+
Output of a query handler.
42+
43+
results: list of results. Each result can be a dict or a dataclass.
44+
query_info: Info about the query.
45+
"""
46+
47+
results: list[R]
48+
query_info: QueryInfo | None = None

python/cocoindex/runtime.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@
55

66
import threading
77
import asyncio
8-
from typing import Any, Coroutine, TypeVar
8+
import inspect
9+
from typing import Any, Callable, Coroutine, TypeVar, Awaitable
910

1011

1112
T = TypeVar("T")
@@ -35,3 +36,9 @@ def run(self, coro: Coroutine[Any, Any, T]) -> T:
3536

3637

3738
execution_context = _ExecutionContext()
39+
40+
41+
def to_async_call(call: Callable[..., Any]) -> Callable[..., Awaitable[Any]]:
42+
if inspect.iscoroutinefunction(call):
43+
return call
44+
return lambda *args, **kwargs: asyncio.to_thread(lambda: call(*args, **kwargs))

src/py/mod.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,12 @@ impl Flow {
434434
SetupChangeBundle(Arc::new(bundle))
435435
}
436436

437-
pub fn add_query_handler(&self, name: String, handler: Py<PyAny>) -> PyResult<()> {
437+
pub fn add_query_handler(
438+
&self,
439+
name: String,
440+
handler: Py<PyAny>,
441+
handler_info: Pythonized<Option<QueryHandlerInfo>>,
442+
) -> PyResult<()> {
438443
struct PyQueryHandler {
439444
handler: Py<PyAny>,
440445
}
@@ -483,7 +488,7 @@ impl Flow {
483488
handlers.insert(
484489
name,
485490
QueryHandlerContext {
486-
info: Arc::new(QueryHandlerInfo {}),
491+
info: Arc::new(handler_info.into_inner().unwrap_or_default()),
487492
handler: Arc::new(PyQueryHandler { handler }),
488493
},
489494
);

src/service/query_handler.rs

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,16 @@
11
use crate::prelude::*;
22

3-
#[derive(Serialize)]
4-
pub struct QueryHandlerInfo {}
3+
#[derive(Serialize, Deserialize, Default)]
4+
pub struct QueryHandlerResultFields {
5+
embedding: Vec<String>,
6+
score: Option<String>,
7+
}
8+
9+
#[derive(Serialize, Deserialize, Default)]
10+
pub struct QueryHandlerInfo {
11+
#[serde(default)]
12+
result_fields: QueryHandlerResultFields,
13+
}
514

615
#[derive(Serialize, Deserialize)]
716
pub struct QueryInput {
@@ -15,7 +24,7 @@ pub struct QueryInfo {
1524

1625
#[derive(Serialize, Deserialize)]
1726
pub struct QueryOutput {
18-
pub results: Vec<IndexMap<String, serde_json::Value>>,
27+
pub results: Vec<HashMap<String, serde_json::Value>>,
1928
pub query_info: QueryInfo,
2029
}
2130

0 commit comments

Comments
 (0)