Skip to content

Commit 570d6ad

Browse files
authored
Merge branch 'main' into patches
2 parents 6485625 + 7e6a6f4 commit 570d6ad

File tree

14 files changed

+683
-129
lines changed

14 files changed

+683
-129
lines changed

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
</div>
1919

2020
<div align="center">
21-
<!-- <a href="https://trendshift.io/repositories/13939" target="_blank"><img src="https://trendshift.io/api/badge/repositories/13939" alt="cocoindex-io%2Fcocoindex | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a> -->
21+
<a href="https://trendshift.io/repositories/13939" target="_blank"><img src="https://trendshift.io/api/badge/repositories/13939" alt="cocoindex-io%2Fcocoindex | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
2222
</div>
2323

2424

docs/docs/core/data_types.mdx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ You don't need to spell out any data type explicitly when you define the flow.
1919
All you need to do is to make sure the data passed to functions and targets are compatible with them.
2020

2121
Each type in CocoIndex type system is mapped to one or multiple types in Python.
22-
When you define a [custom function](/docs/core/custom_function), you need to annotate the data types of arguments and return values.
22+
When you define a [custom function](/docs/custom_ops/custom_functions), you need to annotate the data types of arguments and return values.
2323

2424
* When you pass a Python value to the engine (e.g. return values of a custom function), a specific type annotation is required.
2525
The type annotation needs to be specific in describing the target data type, as it provides the ground truth of the data type in the flow.

docs/docusaurus.config.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ const config: Config = {
5757
from: '/core/initialization',
5858
to: '/core/settings',
5959
},
60+
{
61+
from: '/core/custom_function',
62+
to: '/custom_ops/custom_functions',
63+
},
6064
{
6165
from: '/ops/storages',
6266
to: '/ops/targets',

docs/sidebars.ts

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ const sidebars: SidebarsConfig = {
3131
'core/settings',
3232
'core/flow_methods',
3333
'core/cli',
34-
'core/custom_function',
3534
],
3635
},
3736
{
@@ -44,6 +43,14 @@ const sidebars: SidebarsConfig = {
4443
'ops/targets',
4544
],
4645
},
46+
{
47+
type: 'category',
48+
label: 'Custom Operations',
49+
collapsed: false,
50+
items: [
51+
'custom_ops/custom_functions',
52+
],
53+
},
4754
{
4855
type: 'category',
4956
label: 'AI Support',

pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,12 @@ features = ["pyo3/extension-module"]
3030

3131
[project.optional-dependencies]
3232
dev = ["pytest", "ruff", "mypy", "pre-commit"]
33+
3334
embeddings = ["sentence-transformers>=3.3.1"]
34-
all = ["cocoindex[embeddings]"]
35+
36+
# We need to repeat the dependency above to make it available for the `all` feature.
37+
# Indirect dependencies such as "cocoindex[embeddings]" will not work for local development.
38+
all = ["sentence-transformers>=3.3.1"]
3539

3640
[tool.mypy]
3741
python_version = "3.11"

python/cocoindex/convert.py

Lines changed: 107 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
Utilities to convert between Python and engine values.
33
"""
44

5+
from __future__ import annotations
6+
57
import dataclasses
68
import datetime
79
import inspect
@@ -28,6 +30,24 @@
2830
is_struct_type,
2931
)
3032

33+
class ChildFieldPath:
34+
"""Context manager to append a field to field_path on enter and pop it on exit."""
35+
36+
_field_path: list[str]
37+
_field_name: str
38+
39+
def __init__(self, field_path: list[str], field_name: str):
40+
self._field_path: list[str] = field_path
41+
self._field_name = field_name
42+
43+
def __enter__(self) -> ChildFieldPath:
44+
self._field_path.append(self._field_name)
45+
return self
46+
47+
def __exit__(self, _exc_type: Any, _exc_val: Any, _exc_tb: Any) -> None:
48+
self._field_path.pop()
49+
50+
3151
_CONVERTIBLE_KINDS = {
3252
("Float32", "Float64"),
3353
("LocalDateTime", "OffsetDateTime"),
@@ -48,7 +68,6 @@ def _encode_engine_value_core(
4868
type_variant: AnalyzedTypeInfo | None = None,
4969
) -> Any:
5070
"""Core encoding logic for converting Python values to engine values."""
51-
5271
if dataclasses.is_dataclass(value):
5372
fields = dataclasses.fields(value)
5473
return [
@@ -200,66 +219,65 @@ def make_engine_value_decoder(
200219
)
201220

202221
if src_type_kind == "Struct":
203-
return _make_engine_struct_value_decoder(
222+
return make_engine_struct_decoder(
204223
field_path,
205224
src_type["fields"],
206225
dst_type_info,
207226
)
208227

209228
if src_type_kind in TABLE_TYPES:
210-
field_path.append("[*]")
211-
engine_fields_schema = src_type["row"]["fields"]
229+
with ChildFieldPath(field_path, "[*]"):
230+
engine_fields_schema = src_type["row"]["fields"]
212231

213-
if src_type_kind == "LTable":
214-
if isinstance(dst_type_variant, AnalyzedAnyType):
215-
return _make_engine_ltable_to_list_dict_decoder(
216-
field_path, engine_fields_schema
217-
)
218-
if not isinstance(dst_type_variant, AnalyzedListType):
219-
raise ValueError(
220-
f"Type mismatch for `{''.join(field_path)}`: "
221-
f"declared `{dst_type_info.core_type}`, a list type expected"
232+
if src_type_kind == "LTable":
233+
if isinstance(dst_type_variant, AnalyzedAnyType):
234+
return _make_engine_ltable_to_list_dict_decoder(
235+
field_path, engine_fields_schema
236+
)
237+
if not isinstance(dst_type_variant, AnalyzedListType):
238+
raise ValueError(
239+
f"Type mismatch for `{''.join(field_path)}`: "
240+
f"declared `{dst_type_info.core_type}`, a list type expected"
241+
)
242+
row_decoder = make_engine_struct_decoder(
243+
field_path,
244+
engine_fields_schema,
245+
analyze_type_info(dst_type_variant.elem_type),
222246
)
223-
row_decoder = _make_engine_struct_value_decoder(
224-
field_path,
225-
engine_fields_schema,
226-
analyze_type_info(dst_type_variant.elem_type),
227-
)
228247

229-
def decode(value: Any) -> Any | None:
230-
if value is None:
231-
return None
232-
return [row_decoder(v) for v in value]
248+
def decode(value: Any) -> Any | None:
249+
if value is None:
250+
return None
251+
return [row_decoder(v) for v in value]
233252

234-
elif src_type_kind == "KTable":
235-
if isinstance(dst_type_variant, AnalyzedAnyType):
236-
return _make_engine_ktable_to_dict_dict_decoder(
237-
field_path, engine_fields_schema
253+
elif src_type_kind == "KTable":
254+
if isinstance(dst_type_variant, AnalyzedAnyType):
255+
return _make_engine_ktable_to_dict_dict_decoder(
256+
field_path, engine_fields_schema
257+
)
258+
if not isinstance(dst_type_variant, AnalyzedDictType):
259+
raise ValueError(
260+
f"Type mismatch for `{''.join(field_path)}`: "
261+
f"declared `{dst_type_info.core_type}`, a dict type expected"
262+
)
263+
264+
key_field_schema = engine_fields_schema[0]
265+
field_path.append(f".{key_field_schema.get('name', KEY_FIELD_NAME)}")
266+
key_decoder = make_engine_value_decoder(
267+
field_path, key_field_schema["type"], dst_type_variant.key_type
238268
)
239-
if not isinstance(dst_type_variant, AnalyzedDictType):
240-
raise ValueError(
241-
f"Type mismatch for `{''.join(field_path)}`: "
242-
f"declared `{dst_type_info.core_type}`, a dict type expected"
269+
field_path.pop()
270+
value_decoder = make_engine_struct_decoder(
271+
field_path,
272+
engine_fields_schema[1:],
273+
analyze_type_info(dst_type_variant.value_type),
243274
)
244275

245-
key_field_schema = engine_fields_schema[0]
246-
field_path.append(f".{key_field_schema.get('name', KEY_FIELD_NAME)}")
247-
key_decoder = make_engine_value_decoder(
248-
field_path, key_field_schema["type"], dst_type_variant.key_type
249-
)
250-
field_path.pop()
251-
value_decoder = _make_engine_struct_value_decoder(
252-
field_path,
253-
engine_fields_schema[1:],
254-
analyze_type_info(dst_type_variant.value_type),
255-
)
256-
257-
def decode(value: Any) -> Any | None:
258-
if value is None:
259-
return None
260-
return {key_decoder(v[0]): value_decoder(v[1:]) for v in value}
276+
def decode(value: Any) -> Any | None:
277+
if value is None:
278+
return None
279+
return {key_decoder(v[0]): value_decoder(v[1:]) for v in value}
261280

262-
field_path.pop()
263281
return decode
264282

265283
if src_type_kind == "Union":
@@ -274,22 +292,22 @@ def decode(value: Any) -> Any | None:
274292
src_type_variants = src_type["types"]
275293
decoders = []
276294
for i, src_type_variant in enumerate(src_type_variants):
277-
src_field_path = field_path + [f"[{i}]"]
278-
decoder = None
279-
for dst_type_variant in dst_type_variants:
280-
try:
281-
decoder = make_engine_value_decoder(
282-
src_field_path, src_type_variant, dst_type_variant
295+
with ChildFieldPath(field_path, f"[{i}]"):
296+
decoder = None
297+
for dst_type_variant in dst_type_variants:
298+
try:
299+
decoder = make_engine_value_decoder(
300+
field_path, src_type_variant, dst_type_variant
301+
)
302+
break
303+
except ValueError:
304+
pass
305+
if decoder is None:
306+
raise ValueError(
307+
f"Type mismatch for `{''.join(field_path)}`: "
308+
f"cannot find matched target type for source type variant {src_type_variant}"
283309
)
284-
break
285-
except ValueError:
286-
pass
287-
if decoder is None:
288-
raise ValueError(
289-
f"Type mismatch for `{''.join(field_path)}`: "
290-
f"cannot find matched target type for source type variant {src_type_variant}"
291-
)
292-
decoders.append(decoder)
310+
decoders.append(decoder)
293311
return lambda value: decoders[value[0]](value[1])
294312

295313
if isinstance(dst_type_variant, AnalyzedAnyType):
@@ -368,7 +386,7 @@ def decode_scalar(value: Any) -> Any | None:
368386
return lambda value: value
369387

370388

371-
def _make_engine_struct_value_decoder(
389+
def make_engine_struct_decoder(
372390
field_path: list[str],
373391
src_fields: list[dict[str, Any]],
374392
dst_type_info: AnalyzedTypeInfo,
@@ -426,25 +444,24 @@ def make_closure_for_value(
426444
name: str, param: inspect.Parameter
427445
) -> Callable[[list[Any]], Any]:
428446
src_idx = src_name_to_idx.get(name)
429-
if src_idx is not None:
430-
field_path.append(f".{name}")
431-
field_decoder = make_engine_value_decoder(
432-
field_path, src_fields[src_idx]["type"], param.annotation
433-
)
434-
field_path.pop()
435-
return (
436-
lambda values: field_decoder(values[src_idx])
437-
if len(values) > src_idx
438-
else param.default
439-
)
447+
with ChildFieldPath(field_path, f".{name}"):
448+
if src_idx is not None:
449+
field_decoder = make_engine_value_decoder(
450+
field_path, src_fields[src_idx]["type"], param.annotation
451+
)
452+
return (
453+
lambda values: field_decoder(values[src_idx])
454+
if len(values) > src_idx
455+
else param.default
456+
)
440457

441-
default_value = param.default
442-
if default_value is inspect.Parameter.empty:
443-
raise ValueError(
444-
f"Field without default value is missing in input: {''.join(field_path)}"
445-
)
458+
default_value = param.default
459+
if default_value is inspect.Parameter.empty:
460+
raise ValueError(
461+
f"Field without default value is missing in input: {''.join(field_path)}"
462+
)
446463

447-
return lambda _: default_value
464+
return lambda _: default_value
448465

449466
field_value_decoder = [
450467
make_closure_for_value(name, param) for (name, param) in parameters.items()
@@ -464,13 +481,12 @@ def _make_engine_struct_to_dict_decoder(
464481
field_decoders = []
465482
for i, field_schema in enumerate(src_fields):
466483
field_name = field_schema["name"]
467-
field_path.append(f".{field_name}")
468-
field_decoder = make_engine_value_decoder(
469-
field_path,
470-
field_schema["type"],
471-
Any, # Use Any for recursive decoding
472-
)
473-
field_path.pop()
484+
with ChildFieldPath(field_path, f".{field_name}"):
485+
field_decoder = make_engine_value_decoder(
486+
field_path,
487+
field_schema["type"],
488+
Any, # Use Any for recursive decoding
489+
)
474490
field_decoders.append((field_name, field_decoder))
475491

476492
def decode_to_dict(values: list[Any] | None) -> dict[str, Any] | None:
@@ -527,9 +543,10 @@ def _make_engine_ktable_to_dict_dict_decoder(
527543
value_fields_schema = src_fields[1:]
528544

529545
# Create decoders
530-
field_path.append(f".{key_field_schema.get('name', KEY_FIELD_NAME)}")
531-
key_decoder = make_engine_value_decoder(field_path, key_field_schema["type"], Any)
532-
field_path.pop()
546+
with ChildFieldPath(field_path, f".{key_field_schema.get('name', KEY_FIELD_NAME)}"):
547+
key_decoder = make_engine_value_decoder(
548+
field_path, key_field_schema["type"], Any
549+
)
533550

534551
value_decoder = _make_engine_struct_to_dict_decoder(field_path, value_fields_schema)
535552

0 commit comments

Comments
 (0)