Skip to content

Commit 379bc30

Browse files
authored
Support Python->Rust Struct/Table type bindings. #19 (#57)
Support Python->Rust Struct/Table type bindings.
1 parent eea9fbd commit 379bc30

File tree

2 files changed

+71
-6
lines changed

2 files changed

+71
-6
lines changed

python/cocoindex/op.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
"""
22
Facilities for defining cocoindex operations.
33
"""
4+
import dataclasses
45
import inspect
56

67
from typing import get_type_hints, Protocol, Any, Callable, dataclass_transform
7-
from dataclasses import dataclass
88
from enum import Enum
99
from threading import Lock
1010

@@ -28,7 +28,7 @@ def __new__(mcs, name, bases, attrs, category: OpCategory | None = None):
2828
setattr(cls, '_op_category', category)
2929
else:
3030
# It's the specific class providing specific fields.
31-
cls = dataclass(cls)
31+
cls = dataclasses.dataclass(cls)
3232
return cls
3333

3434
class SourceSpec(metaclass=SpecMeta, category=OpCategory.SOURCE): # pylint: disable=too-few-public-methods
@@ -59,6 +59,14 @@ def __call__(self, spec: dict[str, Any], *args, **kwargs):
5959
result_type = executor.analyze(*args, **kwargs)
6060
return (dump_type(result_type), executor)
6161

62+
def to_engine_value(value: Any) -> Any:
63+
"""Convert a Python value to an engine value."""
64+
if dataclasses.is_dataclass(value):
65+
return [to_engine_value(getattr(value, f.name)) for f in dataclasses.fields(value)]
66+
elif isinstance(value, list) or isinstance(value, tuple):
67+
return [to_engine_value(v) for v in value]
68+
return value
69+
6270
_gpu_dispatch_lock = Lock()
6371

6472
def executor_class(gpu: bool = False, cache: bool = False, behavior_version: int | None = None) -> Callable[[type], type]:
@@ -162,9 +170,10 @@ def __call__(self, *args, **kwargs):
162170
# For now, we use a lock to ensure only one task is executed at a time.
163171
# TODO: Implement multi-processing dispatching.
164172
with _gpu_dispatch_lock:
165-
return super().__call__(*args, **kwargs)
173+
output = super().__call__(*args, **kwargs)
166174
else:
167-
return super().__call__(*args, **kwargs)
175+
output = super().__call__(*args, **kwargs)
176+
return to_engine_value(output)
168177

169178
_WrappedClass.__name__ = cls.__name__
170179

src/ops/py_factory.rs

Lines changed: 58 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
1-
use std::sync::Arc;
1+
use std::{collections::BTreeMap, sync::Arc};
22

33
use axum::async_trait;
44
use blocking::unblock;
55
use futures::FutureExt;
66
use pyo3::{
77
exceptions::PyException,
88
pyclass, pymethods,
9-
types::{IntoPyDict, PyAnyMethods, PyString, PyTuple},
9+
types::{IntoPyDict, PyAnyMethods, PyList, PyString, PyTuple},
1010
Bound, IntoPyObjectExt, Py, PyAny, PyResult, Python,
1111
};
1212

1313
use crate::{
1414
base::{schema, value},
1515
builder::plan,
16+
py::IntoPyResult,
1617
};
1718
use anyhow::Result;
1819

@@ -89,6 +90,28 @@ fn basic_value_from_py_object<'py>(
8990
Ok(result)
9091
}
9192

93+
fn field_values_from_py_object<'py>(
94+
schema: &schema::StructSchema,
95+
v: &Bound<'py, PyAny>,
96+
) -> PyResult<value::FieldValues> {
97+
let list = v.extract::<Vec<Bound<'py, PyAny>>>()?;
98+
if list.len() != schema.fields.len() {
99+
return Err(PyException::new_err(format!(
100+
"struct field number mismatch, expected {}, got {}",
101+
schema.fields.len(),
102+
list.len()
103+
)));
104+
}
105+
Ok(value::FieldValues {
106+
fields: schema
107+
.fields
108+
.iter()
109+
.zip(list.into_iter())
110+
.map(|(f, v)| value_from_py_object(&f.value_type.typ, &v))
111+
.collect::<PyResult<Vec<_>>>()?,
112+
})
113+
}
114+
92115
fn value_from_py_object<'py>(
93116
typ: &schema::ValueType,
94117
v: &Bound<'py, PyAny>,
@@ -100,6 +123,39 @@ fn value_from_py_object<'py>(
100123
schema::ValueType::Basic(typ) => {
101124
value::Value::Basic(basic_value_from_py_object(typ, v)?)
102125
}
126+
schema::ValueType::Struct(schema) => {
127+
value::Value::Struct(field_values_from_py_object(schema, v)?)
128+
}
129+
schema::ValueType::Collection(schema) => {
130+
let list = v.extract::<Vec<Bound<'py, PyAny>>>()?;
131+
let values = list
132+
.into_iter()
133+
.map(|v| field_values_from_py_object(&schema.row, &v))
134+
.collect::<PyResult<Vec<_>>>()?;
135+
match schema.kind {
136+
schema::CollectionKind::Collection => {
137+
value::Value::Collection(values.into_iter().map(|v| v.into()).collect())
138+
}
139+
schema::CollectionKind::List => {
140+
value::Value::List(values.into_iter().map(|v| v.into()).collect())
141+
}
142+
schema::CollectionKind::Table => value::Value::Table(
143+
values
144+
.into_iter()
145+
.map(|v| {
146+
let mut iter = v.fields.into_iter();
147+
let key = iter.next().unwrap().to_key().into_py_result()?;
148+
Ok((
149+
key,
150+
value::ScopeValue(value::FieldValues {
151+
fields: iter.collect::<Vec<_>>(),
152+
}),
153+
))
154+
})
155+
.collect::<PyResult<BTreeMap<_, _>>>()?,
156+
),
157+
}
158+
}
103159
_ => {
104160
return Err(PyException::new_err(format!(
105161
"unsupported value type: {}",

0 commit comments

Comments
 (0)