Skip to content

Commit fca729a

Browse files
authored
Extract Rust<->Python binding logic into a module for reuse (#122)
1 parent ac0a8aa commit fca729a

File tree

3 files changed

+227
-217
lines changed

3 files changed

+227
-217
lines changed

src/ops/py_factory.rs

Lines changed: 7 additions & 172 deletions
Original file line numberDiff line numberDiff line change
@@ -1,190 +1,25 @@
1-
use std::{collections::BTreeMap, sync::Arc};
1+
use std::sync::Arc;
22

33
use axum::async_trait;
44
use futures::FutureExt;
55
use pyo3::{
6-
exceptions::PyException,
76
pyclass, pymethods,
8-
types::{IntoPyDict, PyAnyMethods, PyList, PyString, PyTuple},
9-
Bound, IntoPyObjectExt, Py, PyAny, PyResult, Python,
7+
types::{IntoPyDict, PyString, PyTuple},
8+
IntoPyObjectExt, Py, PyAny, Python,
109
};
1110
use pythonize::pythonize;
1211

1312
use crate::{
1413
base::{schema, value},
1514
builder::plan,
16-
py::IntoPyResult,
15+
py,
1716
};
1817
use anyhow::Result;
1918

2019
use super::sdk::{
2120
ExecutorFuture, FlowInstanceContext, SimpleFunctionExecutor, SimpleFunctionFactory,
2221
};
2322

24-
fn basic_value_to_py_object<'py>(
25-
py: Python<'py>,
26-
v: &value::BasicValue,
27-
) -> PyResult<Bound<'py, PyAny>> {
28-
let result = match v {
29-
value::BasicValue::Bytes(v) => v.into_bound_py_any(py)?,
30-
value::BasicValue::Str(v) => v.into_bound_py_any(py)?,
31-
value::BasicValue::Bool(v) => v.into_bound_py_any(py)?,
32-
value::BasicValue::Int64(v) => v.into_bound_py_any(py)?,
33-
value::BasicValue::Float32(v) => v.into_bound_py_any(py)?,
34-
value::BasicValue::Float64(v) => v.into_bound_py_any(py)?,
35-
value::BasicValue::Vector(v) => v
36-
.iter()
37-
.map(|v| basic_value_to_py_object(py, v))
38-
.collect::<PyResult<Vec<_>>>()?
39-
.into_bound_py_any(py)?,
40-
_ => {
41-
return Err(PyException::new_err(format!(
42-
"unsupported value type: {}",
43-
v.kind()
44-
)))
45-
}
46-
};
47-
Ok(result)
48-
}
49-
50-
fn field_values_to_py_object<'py, 'a>(
51-
py: Python<'py>,
52-
values: impl Iterator<Item = &'a value::Value>,
53-
) -> PyResult<Bound<'py, PyAny>> {
54-
let fields = values
55-
.map(|v| value_to_py_object(py, v))
56-
.collect::<PyResult<Vec<_>>>()?;
57-
Ok(PyTuple::new(py, fields)?.into_any())
58-
}
59-
60-
fn value_to_py_object<'py>(py: Python<'py>, v: &value::Value) -> PyResult<Bound<'py, PyAny>> {
61-
let result = match v {
62-
value::Value::Null => py.None().into_bound(py),
63-
value::Value::Basic(v) => basic_value_to_py_object(py, v)?,
64-
value::Value::Struct(v) => field_values_to_py_object(py, v.fields.iter())?,
65-
value::Value::Collection(v) | value::Value::List(v) => {
66-
let rows = v
67-
.iter()
68-
.map(|v| field_values_to_py_object(py, v.0.fields.iter()))
69-
.collect::<PyResult<Vec<_>>>()?;
70-
PyList::new(py, rows)?.into_any()
71-
}
72-
value::Value::Table(v) => {
73-
let rows = v
74-
.iter()
75-
.map(|(k, v)| {
76-
field_values_to_py_object(
77-
py,
78-
std::iter::once(&value::Value::from(k.clone())).chain(v.0.fields.iter()),
79-
)
80-
})
81-
.collect::<PyResult<Vec<_>>>()?;
82-
PyList::new(py, rows)?.into_any()
83-
}
84-
};
85-
Ok(result)
86-
}
87-
88-
fn basic_value_from_py_object<'py>(
89-
typ: &schema::BasicValueType,
90-
v: &Bound<'py, PyAny>,
91-
) -> PyResult<value::BasicValue> {
92-
let result = match typ {
93-
schema::BasicValueType::Bytes => {
94-
value::BasicValue::Bytes(Arc::from(v.extract::<Vec<u8>>()?))
95-
}
96-
schema::BasicValueType::Str => value::BasicValue::Str(Arc::from(v.extract::<String>()?)),
97-
schema::BasicValueType::Bool => value::BasicValue::Bool(v.extract::<bool>()?),
98-
schema::BasicValueType::Int64 => value::BasicValue::Int64(v.extract::<i64>()?),
99-
schema::BasicValueType::Float32 => value::BasicValue::Float32(v.extract::<f32>()?),
100-
schema::BasicValueType::Float64 => value::BasicValue::Float64(v.extract::<f64>()?),
101-
schema::BasicValueType::Vector(elem) => value::BasicValue::Vector(Arc::from(
102-
v.extract::<Vec<Bound<'py, PyAny>>>()?
103-
.into_iter()
104-
.map(|v| basic_value_from_py_object(&elem.element_type, &v))
105-
.collect::<PyResult<Vec<_>>>()?,
106-
)),
107-
_ => {
108-
return Err(PyException::new_err(format!(
109-
"unsupported value type: {}",
110-
typ
111-
)))
112-
}
113-
};
114-
Ok(result)
115-
}
116-
117-
fn field_values_from_py_object<'py>(
118-
schema: &schema::StructSchema,
119-
v: &Bound<'py, PyAny>,
120-
) -> PyResult<value::FieldValues> {
121-
let list = v.extract::<Vec<Bound<'py, PyAny>>>()?;
122-
if list.len() != schema.fields.len() {
123-
return Err(PyException::new_err(format!(
124-
"struct field number mismatch, expected {}, got {}",
125-
schema.fields.len(),
126-
list.len()
127-
)));
128-
}
129-
Ok(value::FieldValues {
130-
fields: schema
131-
.fields
132-
.iter()
133-
.zip(list.into_iter())
134-
.map(|(f, v)| value_from_py_object(&f.value_type.typ, &v))
135-
.collect::<PyResult<Vec<_>>>()?,
136-
})
137-
}
138-
139-
fn value_from_py_object<'py>(
140-
typ: &schema::ValueType,
141-
v: &Bound<'py, PyAny>,
142-
) -> PyResult<value::Value> {
143-
let result = if v.is_none() {
144-
value::Value::Null
145-
} else {
146-
match typ {
147-
schema::ValueType::Basic(typ) => {
148-
value::Value::Basic(basic_value_from_py_object(typ, v)?)
149-
}
150-
schema::ValueType::Struct(schema) => {
151-
value::Value::Struct(field_values_from_py_object(schema, v)?)
152-
}
153-
schema::ValueType::Collection(schema) => {
154-
let list = v.extract::<Vec<Bound<'py, PyAny>>>()?;
155-
let values = list
156-
.into_iter()
157-
.map(|v| field_values_from_py_object(&schema.row, &v))
158-
.collect::<PyResult<Vec<_>>>()?;
159-
match schema.kind {
160-
schema::CollectionKind::Collection => {
161-
value::Value::Collection(values.into_iter().map(|v| v.into()).collect())
162-
}
163-
schema::CollectionKind::List => {
164-
value::Value::List(values.into_iter().map(|v| v.into()).collect())
165-
}
166-
schema::CollectionKind::Table => value::Value::Table(
167-
values
168-
.into_iter()
169-
.map(|v| {
170-
let mut iter = v.fields.into_iter();
171-
let key = iter.next().unwrap().to_key().into_py_result()?;
172-
Ok((
173-
key,
174-
value::ScopeValue(value::FieldValues {
175-
fields: iter.collect::<Vec<_>>(),
176-
}),
177-
))
178-
})
179-
.collect::<PyResult<BTreeMap<_, _>>>()?,
180-
),
181-
}
182-
}
183-
}
184-
};
185-
Ok(result)
186-
}
187-
18823
#[pyclass(name = "OpArgSchema")]
18924
pub struct PyOpArgSchema {
19025
value_type: crate::py::Pythonized<schema::EnrichedValueType>,
@@ -222,7 +57,7 @@ impl SimpleFunctionExecutor for Arc<PyFunctionExecutor> {
22257
Python::with_gil(|py| -> Result<_> {
22358
let mut args = Vec::with_capacity(self.num_positional_args);
22459
for v in input[0..self.num_positional_args].iter() {
225-
args.push(value_to_py_object(py, v)?);
60+
args.push(py::value_to_py_object(py, v)?);
22661
}
22762

22863
let kwargs = if self.kw_args_names.is_empty() {
@@ -234,7 +69,7 @@ impl SimpleFunctionExecutor for Arc<PyFunctionExecutor> {
23469
.iter()
23570
.zip(input[self.num_positional_args..].iter())
23671
{
237-
kwargs.push((name.bind(py), value_to_py_object(py, v)?));
72+
kwargs.push((name.bind(py), py::value_to_py_object(py, v)?));
23873
}
23974
Some(kwargs)
24075
};
@@ -248,7 +83,7 @@ impl SimpleFunctionExecutor for Arc<PyFunctionExecutor> {
24883
.as_ref(),
24984
)?;
25085

251-
Ok(value_from_py_object(
86+
Ok(py::value_from_py_object(
25287
&self.result_type.typ,
25388
result.bind(py),
25489
)?)

0 commit comments

Comments
 (0)