Skip to content

Commit 4df3ddb

Browse files
authored
feat(query-handler): minor adjustment for the API (#1014)
1 parent 071547a commit 4df3ddb

File tree

6 files changed

+19
-10
lines changed

6 files changed

+19
-10
lines changed

python/cocoindex/flow.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -874,7 +874,10 @@ def add_query_handler(
874874
async def _handler(query: str) -> dict[str, Any]:
875875
handler_result = await async_handler(query)
876876
return {
877-
"results": dump_engine_object(handler_result.results),
877+
"results": [
878+
[(k, dump_engine_object(v)) for (k, v) in result.items()]
879+
for result in handler_result.results
880+
],
878881
"query_info": dump_engine_object(handler_result.query_info),
879882
}
880883

python/cocoindex/query_handler.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import numpy as np
33
from numpy import typing as npt
44
from typing import Generic, TypeVar
5+
from .index import VectorSimilarityMetric
56

67

78
@dataclasses.dataclass
@@ -30,6 +31,7 @@ class QueryInfo:
3031
"""
3132

3233
embedding: list[float] | npt.NDArray[np.float32] | None = None
34+
similarity_metric: VectorSimilarityMetric | None = None
3335

3436

3537
R = TypeVar("R")

src/lib_context.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use crate::prelude::*;
55
use crate::builder::AnalyzedFlow;
66
use crate::execution::source_indexer::SourceIndexingContext;
77
use crate::service::error::ApiError;
8-
use crate::service::query_handler::{QueryHandler, QueryHandlerInfo};
8+
use crate::service::query_handler::{QueryHandler, QueryHandlerSpec};
99
use crate::settings;
1010
use crate::setup::ObjectSetupChange;
1111
use axum::http::StatusCode;
@@ -99,7 +99,7 @@ impl FlowExecutionContext {
9999
}
100100

101101
pub struct QueryHandlerContext {
102-
pub info: Arc<QueryHandlerInfo>,
102+
pub info: Arc<QueryHandlerSpec>,
103103
pub handler: Arc<dyn QueryHandler>,
104104
}
105105

src/py/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use crate::lib_context::{
99
use crate::ops::py_factory::{PyExportTargetFactory, PyOpArgSchema};
1010
use crate::ops::{interface::ExecutorFactory, py_factory::PyFunctionFactory, register_factory};
1111
use crate::server::{self, ServerSettings};
12-
use crate::service::query_handler::QueryHandlerInfo;
12+
use crate::service::query_handler::QueryHandlerSpec;
1313
use crate::settings::Settings;
1414
use crate::setup::{self};
1515
use pyo3::IntoPyObjectExt;
@@ -438,7 +438,7 @@ impl Flow {
438438
&self,
439439
name: String,
440440
handler: Py<PyAny>,
441-
handler_info: Pythonized<Option<QueryHandlerInfo>>,
441+
handler_info: Pythonized<Option<QueryHandlerSpec>>,
442442
) -> PyResult<()> {
443443
struct PyQueryHandler {
444444
handler: Py<PyAny>,

src/service/flows.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use crate::prelude::*;
22

33
use crate::execution::{evaluator, indexing_status, memoization, row_indexer, stats};
44
use crate::lib_context::LibContext;
5-
use crate::service::query_handler::{QueryHandlerInfo, QueryInput, QueryOutput};
5+
use crate::service::query_handler::{QueryHandlerSpec, QueryInput, QueryOutput};
66
use crate::{base::schema::FlowSchema, ops::interface::SourceExecutorReadOptions};
77
use axum::{
88
Json,
@@ -31,7 +31,7 @@ pub async fn get_flow_schema(
3131
pub struct GetFlowResponseData {
3232
flow_spec: spec::FlowInstanceSpec,
3333
data_schema: FlowSchema,
34-
query_handlers_spec: HashMap<String, Arc<QueryHandlerInfo>>,
34+
query_handlers_spec: HashMap<String, Arc<QueryHandlerSpec>>,
3535
}
3636

3737
#[derive(Serialize)]

src/service/query_handler.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
use crate::prelude::*;
1+
use crate::{
2+
base::spec::{FieldName, VectorSimilarityMetric},
3+
prelude::*,
4+
};
25

36
#[derive(Serialize, Deserialize, Default)]
47
pub struct QueryHandlerResultFields {
@@ -7,7 +10,7 @@ pub struct QueryHandlerResultFields {
710
}
811

912
#[derive(Serialize, Deserialize, Default)]
10-
pub struct QueryHandlerInfo {
13+
pub struct QueryHandlerSpec {
1114
#[serde(default)]
1215
result_fields: QueryHandlerResultFields,
1316
}
@@ -20,11 +23,12 @@ pub struct QueryInput {
2023
#[derive(Serialize, Deserialize, Default)]
2124
pub struct QueryInfo {
2225
pub embedding: Option<serde_json::Value>,
26+
pub similarity_metric: Option<VectorSimilarityMetric>,
2327
}
2428

2529
#[derive(Serialize, Deserialize)]
2630
pub struct QueryOutput {
27-
pub results: Vec<HashMap<String, serde_json::Value>>,
31+
pub results: Vec<Vec<(FieldName, serde_json::Value)>>,
2832
pub query_info: QueryInfo,
2933
}
3034

0 commit comments

Comments
 (0)