Skip to content

Commit cfae277

Browse files
authored
Support creating dataclasses in Rust->Python value bindings. #19 (#73)
Support creating dataclasses in Rust->Python value bindings.
1 parent 5492ae7 commit cfae277

File tree

4 files changed

+123
-53
lines changed

4 files changed

+123
-53
lines changed

python/cocoindex/flow.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212
from . import _engine
1313
from . import vector
1414
from . import op
15-
from .typing import encode_type
15+
from .typing import encode_enriched_type
1616

1717
class _NameBuilder:
1818
_existing_names: set[str]
@@ -419,7 +419,7 @@ def __init__(
419419
inspect.Parameter.KEYWORD_ONLY):
420420
raise ValueError(f"Parameter {param_name} is not a parameter can be passed by name")
421421
engine_ds = flow_builder_state.engine_flow_builder.add_direct_input(
422-
param_name, encode_type(param_type))
422+
param_name, encode_enriched_type(param_type))
423423
kwargs[param_name] = DataSlice(_DataSliceState(flow_builder_state, engine_ds))
424424

425425
output = flow_fn(**kwargs)

python/cocoindex/op.py

Lines changed: 91 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
from enum import Enum
99
from threading import Lock
1010

11-
from .typing import encode_type
11+
from .typing import encode_enriched_type, analyze_type_info, COLLECTION_TYPES
1212
from . import _engine
1313

1414

@@ -57,16 +57,86 @@ def __call__(self, spec: dict[str, Any], *args, **kwargs):
5757
spec = self._spec_cls(**spec)
5858
executor = self._executor_cls(spec)
5959
result_type = executor.analyze(*args, **kwargs)
60-
return (encode_type(result_type), executor)
60+
return (encode_enriched_type(result_type), executor)
6161

6262
def _to_engine_value(value: Any) -> Any:
6363
"""Convert a Python value to an engine value."""
6464
if dataclasses.is_dataclass(value):
6565
return [_to_engine_value(getattr(value, f.name)) for f in dataclasses.fields(value)]
66-
elif isinstance(value, list) or isinstance(value, tuple):
66+
if isinstance(value, (list, tuple)):
6767
return [_to_engine_value(v) for v in value]
6868
return value
6969

70+
def _make_engine_struct_value_converter(
71+
field_path: list[str],
72+
src_fields: list[dict[str, Any]],
73+
dst_dataclass_type: type,
74+
) -> Callable[[list], Any]:
75+
"""Make a converter from an engine field values to a Python value."""
76+
77+
src_name_to_idx = {f['name']: i for i, f in enumerate(src_fields)}
78+
def make_closure_for_value(name: str, param: inspect.Parameter) -> Callable[[list], Any]:
79+
src_idx = src_name_to_idx.get(name)
80+
if src_idx is not None:
81+
field_path.append(f'.{name}')
82+
field_converter = _make_engine_value_converter(
83+
field_path, src_fields[src_idx]['type'], param.annotation)
84+
field_path.pop()
85+
return lambda values: field_converter(values[src_idx])
86+
87+
default_value = param.default
88+
if default_value is inspect.Parameter.empty:
89+
raise ValueError(
90+
f"Field without default value is missing in input: {''.join(field_path)}")
91+
92+
return lambda _: default_value
93+
94+
field_value_converters = [
95+
make_closure_for_value(name, param)
96+
for (name, param) in inspect.signature(dst_dataclass_type).parameters.items()]
97+
98+
return lambda values: dst_dataclass_type(
99+
*(converter(values) for converter in field_value_converters))
100+
101+
def _make_engine_value_converter(
102+
field_path: list[str],
103+
src_type: dict[str, Any],
104+
dst_annotation,
105+
) -> Callable[[Any], Any]:
106+
"""Make a converter from an engine value to a Python value."""
107+
108+
src_type_kind = src_type['kind']
109+
110+
if dst_annotation is inspect.Parameter.empty:
111+
if src_type_kind == 'Struct' or src_type_kind in COLLECTION_TYPES:
112+
raise ValueError(f"Missing type annotation for `{''.join(field_path)}`."
113+
f"It's required for {src_type_kind} type.")
114+
return lambda value: value
115+
116+
dst_type_info = analyze_type_info(dst_annotation)
117+
118+
if src_type_kind != dst_type_info.kind:
119+
raise ValueError(
120+
f"Type mismatch for `{''.join(field_path)}`: "
121+
f"passed in {src_type_kind}, declared {dst_annotation} ({dst_type_info.kind})")
122+
123+
if dst_type_info.dataclass_type is not None:
124+
return _make_engine_struct_value_converter(
125+
field_path, src_type['fields'], dst_type_info.dataclass_type)
126+
127+
if src_type_kind in COLLECTION_TYPES:
128+
field_path.append('[*]')
129+
elem_type_info = analyze_type_info(dst_type_info.elem_type)
130+
if elem_type_info.dataclass_type is None:
131+
raise ValueError(f"Type mismatch for `{''.join(field_path)}`: "
132+
f"declared `{dst_type_info.kind}`, a dataclass type expected")
133+
elem_converter = _make_engine_struct_value_converter(
134+
field_path, src_type['row']['fields'], elem_type_info.dataclass_type)
135+
field_path.pop()
136+
return lambda value: [elem_converter(v) for v in value] if value is not None else None
137+
138+
return lambda value: value
139+
70140
_gpu_dispatch_lock = Lock()
71141

72142
def executor_class(gpu: bool = False, cache: bool = False, behavior_version: int | None = None) -> Callable[[type], type]:
@@ -105,6 +175,9 @@ def behavior_version(self):
105175
return behavior_version
106176

107177
class _WrappedClass(cls_type, _Fallback):
178+
_args_converters: list[Callable[[Any], Any]]
179+
_kwargs_converters: dict[str, Callable[[str, Any], Any]]
180+
108181
def __init__(self, spec):
109182
super().__init__()
110183
self.spec = spec
@@ -114,16 +187,19 @@ def analyze(self, *args, **kwargs):
114187
Analyze the spec and arguments. In this phase, argument types should be validated.
115188
It should return the expected result type for the current op.
116189
"""
190+
self._args_converters = []
191+
self._kwargs_converters = {}
192+
117193
# Match arguments with parameters.
118194
next_param_idx = 0
119-
for arg in args:
195+
for arg in args:
120196
if next_param_idx >= len(expected_args):
121-
raise ValueError(f"Too many arguments: {len(args)} > {len(expected_args)}")
197+
raise ValueError(f"Too many arguments passed in: {len(args)} > {len(expected_args)}")
122198
arg_name, arg_param = expected_args[next_param_idx]
123199
if arg_param.kind == inspect.Parameter.KEYWORD_ONLY or arg_param.kind == inspect.Parameter.VAR_KEYWORD:
124-
raise ValueError(f"Too many positional arguments: {len(args)} > {next_param_idx}")
125-
if arg_param.annotation is not inspect.Parameter.empty:
126-
arg.validate_arg(arg_name, encode_type(arg_param.annotation))
200+
raise ValueError(f"Too many positional arguments passed in: {len(args)} > {next_param_idx}")
201+
self._args_converters.append(
202+
_make_engine_value_converter([arg_name], arg.value_type['type'], arg_param.annotation))
127203
if arg_param.kind != inspect.Parameter.VAR_POSITIONAL:
128204
next_param_idx += 1
129205

@@ -136,10 +212,10 @@ def analyze(self, *args, **kwargs):
136212
or arg[1].kind == inspect.Parameter.VAR_KEYWORD),
137213
None)
138214
if expected_arg is None:
139-
raise ValueError(f"Unexpected keyword argument: {kwarg_name}")
215+
raise ValueError(f"Unexpected keyword argument passed in: {kwarg_name}")
140216
arg_param = expected_arg[1]
141-
if arg_param.annotation is not inspect.Parameter.empty:
142-
kwarg.validate_arg(kwarg_name, encode_type(arg_param.annotation))
217+
self._kwargs_converters[kwarg_name] = _make_engine_value_converter(
218+
[kwarg_name], kwarg.value_type['type'], arg_param.annotation)
143219

144220
missing_args = [name for (name, arg) in expected_kwargs
145221
if arg.default is inspect.Parameter.empty
@@ -164,15 +240,17 @@ def prepare(self):
164240
setup_method(self)
165241

166242
def __call__(self, *args, **kwargs):
243+
converted_args = (converter(arg) for converter, arg in zip(self._args_converters, args))
244+
converted_kwargs = {arg_name: self._kwargs_converters[arg_name](arg) for arg_name, arg in kwargs.items()}
167245
if gpu:
168246
# For GPU executions, data-level parallelism is applied, so we don't want to execute different tasks in parallel.
169247
# Besides, multiprocessing is more appropriate for pytorch.
170248
# For now, we use a lock to ensure only one task is executed at a time.
171249
# TODO: Implement multi-processing dispatching.
172250
with _gpu_dispatch_lock:
173-
output = super().__call__(*args, **kwargs)
251+
output = super().__call__(*converted_args, **converted_kwargs)
174252
else:
175-
output = super().__call__(*args, **kwargs)
253+
output = super().__call__(*converted_args, **converted_kwargs)
176254
return _to_engine_value(output)
177255

178256
_WrappedClass.__name__ = cls.__name__

python/cocoindex/typing.py

Lines changed: 30 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,7 @@ class Vector(NamedTuple):
99

1010
class TypeKind(NamedTuple):
1111
kind: str
12+
1213
class TypeAttr:
1314
key: str
1415
value: Any
@@ -24,6 +25,8 @@ def __init__(self, key: str, value: Any):
2425
Range = Annotated[tuple[int, int], TypeKind('Range')]
2526
Json = Annotated[Any, TypeKind('Json')]
2627

28+
COLLECTION_TYPES = ('Table', 'List')
29+
2730
R = TypeVar("R")
2831

2932
if TYPE_CHECKING:
@@ -46,10 +49,6 @@ class List: # type: ignore[unreachable]
4649
def __class_getitem__(cls, item: type[R]):
4750
return Annotated[list[item], TypeKind('List')]
4851

49-
def _dump_field_schema(field: dataclasses.Field) -> dict[str, Any]:
50-
encoded = _encode_enriched_type(field.type)
51-
encoded['name'] = field.name
52-
return encoded
5352
@dataclasses.dataclass
5453
class AnalyzedTypeInfo:
5554
"""
@@ -58,7 +57,7 @@ class AnalyzedTypeInfo:
5857
kind: str
5958
vector_info: Vector | None
6059
elem_type: type | None
61-
struct_fields: tuple[dataclasses.Field, ...] | None
60+
dataclass_type: type | None
6261
attrs: dict[str, Any] | None
6362
nullable: bool = False
6463

@@ -99,18 +98,18 @@ def analyze_type_info(t) -> AnalyzedTypeInfo:
9998
elif isinstance(attr, TypeKind):
10099
kind = attr.kind
101100

102-
struct_fields = None
101+
dataclass_type = None
103102
elem_type = None
104-
if dataclasses.is_dataclass(t):
103+
if isinstance(t, type) and dataclasses.is_dataclass(t):
105104
if kind is None:
106105
kind = 'Struct'
107106
elif kind != 'Struct':
108107
raise ValueError(f"Unexpected type kind for struct: {kind}")
109-
struct_fields = dataclasses.fields(t)
108+
dataclass_type = t
110109
elif base_type is collections.abc.Sequence or base_type is list:
111110
if kind is None:
112111
kind = 'Vector' if vector_info is not None else 'List'
113-
elif kind not in ('Vector', 'List', 'Table'):
112+
elif not (kind == 'Vector' or kind in COLLECTION_TYPES):
114113
raise ValueError(f"Unexpected type kind for list: {kind}")
115114

116115
args = typing.get_args(t)
@@ -134,15 +133,20 @@ def analyze_type_info(t) -> AnalyzedTypeInfo:
134133
raise ValueError(f"type unsupported yet: {base_type}")
135134

136135
return AnalyzedTypeInfo(kind=kind, vector_info=vector_info, elem_type=elem_type,
137-
struct_fields=struct_fields, attrs=attrs, nullable=nullable)
136+
dataclass_type=dataclass_type, attrs=attrs, nullable=nullable)
137+
138+
def _encode_fields_schema(dataclass_type: type) -> list[dict[str, Any]]:
139+
return [{ 'name': field.name,
140+
**encode_enriched_type_info(analyze_type_info(field.type))
141+
} for field in dataclasses.fields(dataclass_type)]
138142

139143
def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:
140144
encoded_type: dict[str, Any] = { 'kind': type_info.kind }
141145

142146
if type_info.kind == 'Struct':
143-
if type_info.struct_fields is None:
144-
raise ValueError("Struct type must have a struct fields")
145-
encoded_type['fields'] = [_dump_field_schema(field) for field in type_info.struct_fields]
147+
if type_info.dataclass_type is None:
148+
raise ValueError("Struct type must have a dataclass type")
149+
encoded_type['fields'] = _encode_fields_schema(type_info.dataclass_type)
146150

147151
elif type_info.kind == 'Vector':
148152
if type_info.vector_info is None:
@@ -152,21 +156,22 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:
152156
encoded_type['element_type'] = _encode_type(analyze_type_info(type_info.elem_type))
153157
encoded_type['dimension'] = type_info.vector_info.dim
154158

155-
elif type_info.kind in ('List', 'Table'):
159+
elif type_info.kind in COLLECTION_TYPES:
156160
if type_info.elem_type is None:
157161
raise ValueError(f"{type_info.kind} type must have an element type")
158-
row_type_inof = analyze_type_info(type_info.elem_type)
159-
if row_type_inof.struct_fields is None:
160-
raise ValueError(f"{type_info.kind} type must have a struct fields")
162+
row_type_info = analyze_type_info(type_info.elem_type)
163+
if row_type_info.dataclass_type is None:
164+
raise ValueError(f"{type_info.kind} type must have a dataclass type")
161165
encoded_type['row'] = {
162-
'fields': [_dump_field_schema(field) for field in row_type_inof.struct_fields],
166+
'fields': _encode_fields_schema(row_type_info.dataclass_type),
163167
}
164168

165169
return encoded_type
166170

167-
def _encode_enriched_type(t) -> dict[str, Any]:
168-
enriched_type_info = analyze_type_info(t)
169-
171+
def encode_enriched_type_info(enriched_type_info: AnalyzedTypeInfo) -> dict[str, Any]:
172+
"""
173+
Encode an enriched type info to a CocoIndex engine's type representation
174+
"""
170175
encoded: dict[str, Any] = {'type': _encode_type(enriched_type_info)}
171176

172177
if enriched_type_info.attrs is not None:
@@ -178,10 +183,11 @@ def _encode_enriched_type(t) -> dict[str, Any]:
178183
return encoded
179184

180185

181-
def encode_type(t) -> dict[str, Any] | None:
186+
def encode_enriched_type(t) -> dict[str, Any] | None:
182187
"""
183-
Convert a Python type to a CocoIndex's type in JSON.
188+
Convert a Python type to a CocoIndex engine's type representation
184189
"""
185190
if t is None:
186191
return None
187-
return _encode_enriched_type(t)
192+
193+
return encode_enriched_type_info(analyze_type_info(t))

src/ops/py_factory.rs

Lines changed: 0 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -139,20 +139,6 @@ impl PyOpArgSchema {
139139
fn analyzed_value(&self) -> &crate::py::Pythonized<plan::AnalyzedValueMapping> {
140140
&self.analyzed_value
141141
}
142-
143-
fn validate_arg(
144-
&self,
145-
name: &str,
146-
typ: crate::py::Pythonized<schema::EnrichedValueType>,
147-
) -> PyResult<()> {
148-
if self.value_type.0.typ != typ.0.typ {
149-
return Err(PyException::new_err(format!(
150-
"argument `{}` type mismatch, input type: {}, argument type: {}",
151-
name, self.value_type.0.typ, typ.0.typ
152-
)));
153-
}
154-
Ok(())
155-
}
156142
}
157143

158144
struct PyFunctionExecutor {

0 commit comments

Comments (0)