Skip to content

Commit 22de402

Browse files
committed
feat(query-handler): add query handler registration in engine
1 parent 5f438de commit 22de402

File tree

7 files changed

+146
-1
lines changed

7 files changed

+146
-1
lines changed

python/cocoindex/query.py

Whitespace-only changes.

src/lib_context.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use crate::prelude::*;
55
use crate::builder::AnalyzedFlow;
66
use crate::execution::source_indexer::SourceIndexingContext;
77
use crate::service::error::ApiError;
8+
use crate::service::query_handler::{QueryHandler, QueryHandlerInfo};
89
use crate::settings;
910
use crate::setup::ObjectSetupChange;
1011
use axum::http::StatusCode;
@@ -97,9 +98,15 @@ impl FlowExecutionContext {
9798
}
9899
}
99100

101+
/// A registered query handler together with its metadata.
pub struct QueryHandlerContext {
102+
/// Shared metadata for the handler; this is the part exposed when handlers are listed.
pub info: Arc<QueryHandlerInfo>,
103+
/// The handler implementation, shared as a trait object.
pub handler: Arc<dyn QueryHandler>,
104+
}
105+
100106
/// State held for one flow: the analyzed flow, its execution context, and
/// the query handlers registered against it.
pub struct FlowContext {
101107
/// The analyzed flow, shared via `Arc`.
pub flow: Arc<AnalyzedFlow>,
102108
/// Execution context behind a tokio (async) read-write lock.
execution_ctx: Arc<tokio::sync::RwLock<FlowExecutionContext>>,
109+
/// Query handlers keyed by the name they were registered under.
/// Created empty; entries are added by handler registration.
pub query_handlers: RwLock<HashMap<String, QueryHandlerContext>>,
103110
}
104111

105112
impl FlowContext {
@@ -117,6 +124,7 @@ impl FlowContext {
117124
Ok(Self {
118125
flow,
119126
execution_ctx,
127+
query_handlers: RwLock::new(HashMap::new()),
120128
})
121129
}
122130

src/py/mod.rs

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,13 @@ use crate::prelude::*;
33

44
use crate::base::schema::{FieldSchema, ValueType};
55
use crate::base::spec::{NamedSpec, OutputMode, ReactiveOpSpec, SpecFormatter};
6-
use crate::lib_context::{clear_lib_context, get_auth_registry, init_lib_context};
6+
use crate::lib_context::{
7+
QueryHandlerContext, clear_lib_context, get_auth_registry, init_lib_context,
8+
};
79
use crate::ops::py_factory::{PyExportTargetFactory, PyOpArgSchema};
810
use crate::ops::{interface::ExecutorFactory, py_factory::PyFunctionFactory, register_factory};
911
use crate::server::{self, ServerSettings};
12+
use crate::service::query_handler::QueryHandlerInfo;
1013
use crate::settings::Settings;
1114
use crate::setup::{self};
1215
use pyo3::IntoPyObjectExt;
@@ -430,6 +433,62 @@ impl Flow {
430433
};
431434
SetupChangeBundle(Arc::new(bundle))
432435
}
436+
437+
/// Registers a Python callable as a query handler for this flow under `name`.
///
/// `handler` is wrapped in a local `QueryHandler` implementation that calls it
/// with the query string as its single positional argument, awaits the
/// returned awaitable on the flow's Python event loop, and converts the result
/// into a `QueryOutput`.
///
/// The entry is stored with `insert`, so registering under a name that is
/// already present replaces the earlier handler (assuming `HashMap` here is
/// the standard map type — confirm against the prelude).
pub fn add_query_handler(&self, name: String, handler: Py<PyAny>) -> PyResult<()> {
438+
// Adapter that lets a Python callable be used as a `QueryHandler`.
struct PyQueryHandler {
439+
handler: Py<PyAny>,
440+
}
441+
442+
#[async_trait]
443+
impl crate::service::query_handler::QueryHandler for PyQueryHandler {
444+
async fn query(
445+
&self,
446+
input: crate::service::query_handler::QueryInput,
447+
flow_ctx: &interface::FlowInstanceContext,
448+
) -> Result<crate::service::query_handler::QueryOutput> {
449+
// Call the Python async function on the flow's event loop.
// The GIL is held only while the call is set up, not while awaiting.
450+
let result_fut = Python::with_gil(|py| -> Result<_> {
451+
let handler = self.handler.clone_ref(py);
452+
// Build positional args: a one-element tuple holding the query string
453+
let args = pyo3::types::PyTuple::new(py, [input.query])?;
454+
let result_coro = handler.call(py, args, None).to_result_with_py_trace(py)?;
455+
456+
// The event loop comes from the flow's Python execution context;
// without one the query fails with an error.
let py_exec_ctx = flow_ctx
457+
.py_exec_ctx
458+
.as_ref()
459+
.ok_or_else(|| anyhow!("Python execution context is missing"))?;
460+
let task_locals = pyo3_async_runtimes::TaskLocals::new(
461+
py_exec_ctx.event_loop.bind(py).clone(),
462+
);
463+
Ok(pyo3_async_runtimes::into_future_with_locals(
464+
&task_locals,
465+
result_coro.into_bound(py),
466+
)?)
467+
})?;
468+
469+
let py_obj = result_fut.await;
470+
// Convert Python result to Rust type with proper traceback handling
// (the GIL is re-acquired for the conversion).
471+
let output = Python::with_gil(|py| -> Result<_> {
472+
let output_any = py_obj.to_result_with_py_trace(py)?;
473+
let output: crate::py::Pythonized<crate::service::query_handler::QueryOutput> =
474+
output_any.extract(py)?;
475+
Ok(output.into_inner())
476+
})?;
477+
478+
Ok(output)
479+
}
480+
}
481+
482+
// NOTE(review): `unwrap` presumably panics if the lock is poisoned
// (assuming a std-style `RwLock`) — confirm this is acceptable here.
let mut handlers = self.0.query_handlers.write().unwrap();
483+
handlers.insert(
484+
name,
485+
QueryHandlerContext {
486+
// `QueryHandlerInfo` has no fields at this point, so the info is empty.
info: Arc::new(QueryHandlerInfo {}),
487+
handler: Arc::new(PyQueryHandler { handler }),
488+
},
489+
);
490+
Ok(())
491+
}
433492
}
434493

435494
#[pyclass]

src/server.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,14 @@ pub async fn init_server(
6161
"/flows/{flowInstName}/data",
6262
routing::get(service::flows::evaluate_data),
6363
)
64+
.route(
65+
"/flows/{flowInstName}/queryHandlers",
66+
routing::get(service::flows::get_query_handlers),
67+
)
68+
.route(
69+
"/flows/{flowInstName}/queryHandlers/{queryHandlerName}",
70+
routing::get(service::flows::query),
71+
)
6472
.route(
6573
"/flows/{flowInstName}/rowStatus",
6674
routing::get(service::flows::get_row_indexing_status),

src/service/flows.rs

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use crate::prelude::*;
22

33
use crate::execution::{evaluator, indexing_status, memoization, row_indexer, stats};
44
use crate::lib_context::LibContext;
5+
use crate::service::query_handler::{QueryHandlerInfo, QueryInput, QueryOutput};
56
use crate::{base::schema::FlowSchema, ops::interface::SourceExecutorReadOptions};
67
use axum::{
78
Json,
@@ -255,3 +256,42 @@ pub async fn get_row_indexing_status(
255256
.await?;
256257
Ok(Json(indexing_status))
257258
}
259+
260+
/// HTTP handler: lists the query handlers registered for a flow.
///
/// Looks up the flow named by the path parameter and responds with a JSON
/// object mapping each handler name to its `QueryHandlerInfo`. Errors from the
/// flow lookup are propagated with `?`.
pub async fn get_query_handlers(
261+
Path(flow_name): Path<String>,
262+
State(lib_context): State<Arc<LibContext>>,
263+
) -> Result<Json<HashMap<String, Arc<QueryHandlerInfo>>>, ApiError> {
264+
let flow_ctx = lib_context.get_flow_context(&flow_name)?;
265+
let query_handlers = flow_ctx.query_handlers.read().unwrap();
266+
Ok(Json(
267+
query_handlers
268+
.iter()
269+
// Cloning `info` clones the `Arc`, not the underlying info value.
.map(|(name, handler)| (name.clone(), handler.info.clone()))
270+
.collect(),
271+
))
272+
}
273+
274+
/// HTTP handler: runs the named query handler of a flow.
///
/// The flow and handler names come from the path; the `QueryInput` is
/// deserialized from the URL query string. Responds with the handler's
/// `QueryOutput` as JSON.
///
/// # Errors
///
/// Propagates the flow-lookup error, returns an `ApiError` with
/// `StatusCode::BAD_REQUEST` when no handler is registered under the given
/// name, and propagates any error returned by the handler itself.
pub async fn query(
275+
Path((flow_name, query_handler_name)): Path<(String, String)>,
276+
Query(query): Query<QueryInput>,
277+
State(lib_context): State<Arc<LibContext>>,
278+
) -> Result<Json<QueryOutput>, ApiError> {
279+
let flow_ctx = lib_context.get_flow_context(&flow_name)?;
280+
// Clone the handler `Arc` inside a block so the read guard is dropped
// before the `.await` below.
let query_handler = {
281+
let query_handlers = flow_ctx.query_handlers.read().unwrap();
282+
query_handlers
283+
.get(&query_handler_name)
284+
.ok_or_else(|| {
285+
// NOTE(review): an unknown handler name is reported as 400;
// 404 may fit better — confirm with API consumers before changing.
ApiError::new(
286+
&format!("query handler not found: {query_handler_name}"),
287+
StatusCode::BAD_REQUEST,
288+
)
289+
})?
290+
.handler
291+
.clone()
292+
};
293+
let query_output = query_handler
294+
// `query` is already a `QueryInput`; no conversion is needed.
.query(query, &flow_ctx.flow.flow_instance_ctx)
295+
.await?;
296+
Ok(Json(query_output))
297+
}

src/service/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
pub(crate) mod error;
22
pub(crate) mod flows;
3+
pub(crate) mod query_handler;

src/service/query_handler.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
use crate::prelude::*;
2+
3+
/// Metadata describing a registered query handler.
///
/// Has no fields yet, so it serializes as an empty value.
#[derive(Serialize)]
4+
pub struct QueryHandlerInfo {}
5+
6+
/// Input handed to a [`QueryHandler`].
#[derive(Serialize, Deserialize)]
7+
pub struct QueryInput {
8+
/// The query text.
pub query: String,
9+
}
10+
11+
/// Additional information returned alongside query results.
///
/// `Default` yields an instance with `embedding` set to `None`.
#[derive(Serialize, Deserialize, Default)]
12+
pub struct QueryInfo {
13+
/// Optional embedding as free-form JSON.
/// NOTE(review): presumably the embedding computed for the query — confirm.
pub embedding: Option<serde_json::Value>,
14+
}
15+
16+
/// Output produced by a [`QueryHandler`].
#[derive(Serialize, Deserialize)]
17+
pub struct QueryOutput {
18+
/// The results; each one maps string keys to JSON values.
pub results: Vec<IndexMap<String, serde_json::Value>>,
19+
/// Extra information about the query itself.
pub query_info: QueryInfo,
20+
}
21+
22+
/// An asynchronous handler that answers queries against a flow.
///
/// The `Send + Sync` bound lets implementations be shared across threads.
#[async_trait]
23+
pub trait QueryHandler: Send + Sync {
24+
/// Runs the query described by `input` using the flow instance context
/// `flow_ctx`, returning the query output or an error.
async fn query(
25+
&self,
26+
input: QueryInput,
27+
flow_ctx: &interface::FlowInstanceContext,
28+
) -> Result<QueryOutput>;
29+
}

0 commit comments

Comments
 (0)