Skip to content

Commit ffcd215

Browse files
authored
Refactor: Move engine->python value conversion logic to convert.py. (#177)
1 parent 801ae8f commit ffcd215

File tree

3 files changed

+90
-82
lines changed

3 files changed

+90
-82
lines changed

examples/manuals_llm_extraction/main.py

Lines changed: 2 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -85,14 +85,9 @@ def manual_extraction_flow(flow_builder: cocoindex.FlowBuilder, data_scope: coco
8585
doc["markdown"] = doc["content"].transform(PdfToMarkdown())
8686
doc["module_info"] = doc["markdown"].transform(
8787
cocoindex.functions.ExtractByLlm(
88-
llm_spec=cocoindex.LlmSpec(
89-
api_type=cocoindex.LlmApiType.OLLAMA,
90-
# See the full list of models: https://ollama.com/library
91-
model="llama3.2"
92-
),
9388
# Replace by this spec below, to use OpenAI API model instead of ollama
94-
# llm_spec=cocoindex.LlmSpec(
95-
# api_type=cocoindex.LlmApiType.OPENAI, model="gpt-4o"),
89+
llm_spec=cocoindex.LlmSpec(
90+
api_type=cocoindex.LlmApiType.OPENAI, model="gpt-4o"),
9691
output_type=ModuleInfo,
9792
instruction="Please extract Python module information from the manual."))
9893
doc["module_summary"] = doc["module_info"].transform(summarize_module)

python/cocoindex/convert.py

Lines changed: 84 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,10 @@
22
Utilities to convert between Python and engine values.
33
"""
44
import dataclasses
5-
from typing import Any
5+
import inspect
6+
7+
from typing import Any, Callable
8+
from .typing import analyze_type_info, COLLECTION_TYPES
69

710
def to_engine_value(value: Any) -> Any:
811
"""Convert a Python value to an engine value."""
@@ -11,3 +14,83 @@ def to_engine_value(value: Any) -> Any:
1114
if isinstance(value, (list, tuple)):
1215
return [to_engine_value(v) for v in value]
1316
return value
17+
18+
def make_engine_value_converter(
    field_path: list[str],
    src_type: dict[str, Any],
    dst_annotation,
) -> Callable[[Any], Any]:
    """
    Make a converter from an engine value to a Python value.

    Args:
        field_path: The path to the field in the engine value. For error messages.
            Segments are pushed while descending into collection rows and are
            always popped again before this function returns or raises.
        src_type: The type of the engine value, mapped from a `cocoindex::base::schema::ValueType`.
        dst_annotation: The type annotation of the Python value, or
            `inspect.Parameter.empty` when the target carries no annotation.

    Returns:
        A converter from an engine value to a Python value.

    Raises:
        ValueError: If a Struct or collection value has no type annotation,
            if the engine type kind differs from the kind of the declared
            annotation, or if a collection's element type is not a dataclass.
    """
    src_type_kind = src_type['kind']

    if dst_annotation is inspect.Parameter.empty:
        # Without an annotation only plain values can be passed through as-is;
        # structs and collections need a target type to be rebuilt into.
        if src_type_kind == 'Struct' or src_type_kind in COLLECTION_TYPES:
            raise ValueError(
                f"Missing type annotation for `{''.join(field_path)}`. "
                f"It's required for {src_type_kind} type.")
        return lambda value: value

    dst_type_info = analyze_type_info(dst_annotation)

    if src_type_kind != dst_type_info.kind:
        raise ValueError(
            f"Type mismatch for `{''.join(field_path)}`: "
            f"passed in {src_type_kind}, declared {dst_annotation} ({dst_type_info.kind})")

    if dst_type_info.dataclass_type is not None:
        return _make_engine_struct_value_converter(
            field_path, src_type['fields'], dst_type_info.dataclass_type)

    if src_type_kind in COLLECTION_TYPES:
        field_path.append('[*]')
        try:
            elem_type_info = analyze_type_info(dst_type_info.elem_type)
            if elem_type_info.dataclass_type is None:
                raise ValueError(f"Type mismatch for `{''.join(field_path)}`: "
                                 f"declared `{dst_type_info.kind}`, a dataclass type expected")
            elem_converter = _make_engine_struct_value_converter(
                field_path, src_type['row']['fields'], elem_type_info.dataclass_type)
        finally:
            # Restore the caller's path even when validation fails, so the
            # shared list never keeps a stale '[*]' segment.
            field_path.pop()
        # A missing collection (None) is passed through unchanged.
        return lambda value: [elem_converter(v) for v in value] if value is not None else None

    return lambda value: value
66+
67+
def _make_engine_struct_value_converter(
    field_path: list[str],
    src_fields: list[dict[str, Any]],
    dst_dataclass_type: type,
) -> Callable[[list], Any]:
    """
    Make a converter from engine struct field values to a Python value.

    Args:
        field_path: The path to the struct in the engine value. For error messages.
            A `.name` segment is pushed while building each field converter
            and is always popped again afterwards.
        src_fields: The engine-side field schemas; each entry is read for its
            `name` and `type` keys. Field values arrive in this same order.
        dst_dataclass_type: The type to construct. Its constructor signature
            determines which fields are looked up and in what order.

    Returns:
        A converter taking the list of engine field values and returning an
        instance of `dst_dataclass_type`.

    Raises:
        ValueError: If a constructor parameter has no matching engine field
            and no default value, or if a nested field converter cannot be built.
    """
    src_name_to_idx = {f['name']: i for i, f in enumerate(src_fields)}

    def make_closure_for_value(name: str, param: inspect.Parameter) -> Callable[[list], Any]:
        src_idx = src_name_to_idx.get(name)
        if src_idx is not None:
            field_path.append(f'.{name}')
            try:
                field_converter = make_engine_value_converter(
                    field_path, src_fields[src_idx]['type'], param.annotation)
            finally:
                # Keep the shared path list balanced even if the nested
                # converter construction raises.
                field_path.pop()
            return lambda values: field_converter(values[src_idx])

        # The engine value has no such field: fall back to the parameter default.
        default_value = param.default
        if default_value is inspect.Parameter.empty:
            # Name the missing field itself, not just the enclosing struct.
            raise ValueError(
                f"Field without default value is missing in input: "
                f"{''.join(field_path)}.{name}")

        return lambda _: default_value

    field_value_converters = [
        make_closure_for_value(name, param)
        for (name, param) in inspect.signature(dst_dataclass_type).parameters.items()]

    # Arguments are passed positionally, in constructor-signature order.
    return lambda values: dst_dataclass_type(
        *(converter(values) for converter in field_value_converters))

python/cocoindex/op.py

Lines changed: 4 additions & 74 deletions
Original file line number | Diff line number | Diff line change
@@ -8,11 +8,10 @@
88
from enum import Enum
99
from threading import Lock
1010

11-
from .typing import encode_enriched_type, analyze_type_info, COLLECTION_TYPES
12-
from .convert import to_engine_value
11+
from .typing import encode_enriched_type
12+
from .convert import to_engine_value, make_engine_value_converter
1313
from . import _engine
1414

15-
1615
class OpCategory(Enum):
1716
"""The category of the operation."""
1817
FUNCTION = "function"
@@ -60,75 +59,6 @@ def __call__(self, spec: dict[str, Any], *args, **kwargs):
6059
result_type = executor.analyze(*args, **kwargs)
6160
return (encode_enriched_type(result_type), executor)
6261

63-
def _make_engine_struct_value_converter(
64-
field_path: list[str],
65-
src_fields: list[dict[str, Any]],
66-
dst_dataclass_type: type,
67-
) -> Callable[[list], Any]:
68-
"""Make a converter from an engine field values to a Python value."""
69-
70-
src_name_to_idx = {f['name']: i for i, f in enumerate(src_fields)}
71-
def make_closure_for_value(name: str, param: inspect.Parameter) -> Callable[[list], Any]:
72-
src_idx = src_name_to_idx.get(name)
73-
if src_idx is not None:
74-
field_path.append(f'.{name}')
75-
field_converter = _make_engine_value_converter(
76-
field_path, src_fields[src_idx]['type'], param.annotation)
77-
field_path.pop()
78-
return lambda values: field_converter(values[src_idx])
79-
80-
default_value = param.default
81-
if default_value is inspect.Parameter.empty:
82-
raise ValueError(
83-
f"Field without default value is missing in input: {''.join(field_path)}")
84-
85-
return lambda _: default_value
86-
87-
field_value_converters = [
88-
make_closure_for_value(name, param)
89-
for (name, param) in inspect.signature(dst_dataclass_type).parameters.items()]
90-
91-
return lambda values: dst_dataclass_type(
92-
*(converter(values) for converter in field_value_converters))
93-
94-
def _make_engine_value_converter(
95-
field_path: list[str],
96-
src_type: dict[str, Any],
97-
dst_annotation,
98-
) -> Callable[[Any], Any]:
99-
"""Make a converter from an engine value to a Python value."""
100-
101-
src_type_kind = src_type['kind']
102-
103-
if dst_annotation is inspect.Parameter.empty:
104-
if src_type_kind == 'Struct' or src_type_kind in COLLECTION_TYPES:
105-
raise ValueError(f"Missing type annotation for `{''.join(field_path)}`."
106-
f"It's required for {src_type_kind} type.")
107-
return lambda value: value
108-
109-
dst_type_info = analyze_type_info(dst_annotation)
110-
111-
if src_type_kind != dst_type_info.kind:
112-
raise ValueError(
113-
f"Type mismatch for `{''.join(field_path)}`: "
114-
f"passed in {src_type_kind}, declared {dst_annotation} ({dst_type_info.kind})")
115-
116-
if dst_type_info.dataclass_type is not None:
117-
return _make_engine_struct_value_converter(
118-
field_path, src_type['fields'], dst_type_info.dataclass_type)
119-
120-
if src_type_kind in COLLECTION_TYPES:
121-
field_path.append('[*]')
122-
elem_type_info = analyze_type_info(dst_type_info.elem_type)
123-
if elem_type_info.dataclass_type is None:
124-
raise ValueError(f"Type mismatch for `{''.join(field_path)}`: "
125-
f"declared `{dst_type_info.kind}`, a dataclass type expected")
126-
elem_converter = _make_engine_struct_value_converter(
127-
field_path, src_type['row']['fields'], elem_type_info.dataclass_type)
128-
field_path.pop()
129-
return lambda value: [elem_converter(v) for v in value] if value is not None else None
130-
131-
return lambda value: value
13262

13363
_gpu_dispatch_lock = Lock()
13464

@@ -190,7 +120,7 @@ def analyze(self, *args, **kwargs):
190120
raise ValueError(
191121
f"Too many positional arguments passed in: {len(args)} > {next_param_idx}")
192122
self._args_converters.append(
193-
_make_engine_value_converter(
123+
make_engine_value_converter(
194124
[arg_name], arg.value_type['type'], arg_param.annotation))
195125
if arg_param.kind != inspect.Parameter.VAR_POSITIONAL:
196126
next_param_idx += 1
@@ -207,7 +137,7 @@ def analyze(self, *args, **kwargs):
207137
if expected_arg is None:
208138
raise ValueError(f"Unexpected keyword argument passed in: {kwarg_name}")
209139
arg_param = expected_arg[1]
210-
self._kwargs_converters[kwarg_name] = _make_engine_value_converter(
140+
self._kwargs_converters[kwarg_name] = make_engine_value_converter(
211141
[kwarg_name], kwarg.value_type['type'], arg_param.annotation)
212142

213143
missing_args = [name for (name, arg) in expected_kwargs

0 commit comments

Comments
 (0)