Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/docs/ai/llm.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ cocoindex.LlmSpec(

You can find the full list of models supported by OpenRouter [here](https://openrouter.ai/models).

### vLLM
### vLLM

Install vLLM:

Expand Down Expand Up @@ -338,4 +338,4 @@ cocoindex.LlmSpec(
```

</TabItem>
</Tabs>
</Tabs>
63 changes: 43 additions & 20 deletions python/cocoindex/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import dataclasses
import datetime
import inspect
import uuid
from enum import Enum
from typing import Any, Callable, Mapping, get_origin

Expand All @@ -14,7 +13,6 @@
from .typing import (
KEY_FIELD_NAME,
TABLE_TYPES,
AnalyzedTypeInfo,
DtypeRegistry,
analyze_type_info,
encode_enriched_type,
Expand Down Expand Up @@ -74,30 +72,58 @@ def make_engine_value_decoder(
Returns:
A decoder from an engine value to a Python value.
"""

src_type_kind = src_type["kind"]

dst_type_info: AnalyzedTypeInfo | None = None
if (
dst_annotation is not None
and dst_annotation is not inspect.Parameter.empty
and dst_annotation is not Any
):
dst_type_info = analyze_type_info(dst_annotation)
if not _is_type_kind_convertible_to(src_type_kind, dst_type_info.kind):
raise ValueError(
f"Type mismatch for `{''.join(field_path)}`: "
f"passed in {src_type_kind}, declared {dst_annotation} ({dst_type_info.kind})"
)

if dst_type_info is None:
dst_is_any = (
dst_annotation is None
or dst_annotation is inspect.Parameter.empty
or dst_annotation is Any
)
if dst_is_any:
if src_type_kind == "Union":
return lambda value: value[1]
if src_type_kind == "Struct" or src_type_kind in TABLE_TYPES:
raise ValueError(
f"Missing type annotation for `{''.join(field_path)}`."
f"It's required for {src_type_kind} type."
)
return lambda value: value

dst_type_info = analyze_type_info(dst_annotation)

if src_type_kind == "Union":
dst_type_variants = (
dst_type_info.union_variant_types
if dst_type_info.union_variant_types is not None
else [dst_annotation]
)
src_type_variants = src_type["types"]
decoders = []
for i, src_type_variant in enumerate(src_type_variants):
src_field_path = field_path + [f"[{i}]"]
decoder = None
for dst_type_variant in dst_type_variants:
try:
decoder = make_engine_value_decoder(
src_field_path, src_type_variant, dst_type_variant
)
break
except ValueError:
pass
if decoder is None:
raise ValueError(
f"Type mismatch for `{''.join(field_path)}`: "
f"cannot find matched target type for source type variant {src_type_variant}"
)
decoders.append(decoder)
return lambda value: decoders[value[0]](value[1])

if not _is_type_kind_convertible_to(src_type_kind, dst_type_info.kind):
raise ValueError(
f"Type mismatch for `{''.join(field_path)}`: "
f"passed in {src_type_kind}, declared {dst_annotation} ({dst_type_info.kind})"
)

if dst_type_info.kind in ("Float32", "Float64", "Int64"):
dst_core_type = dst_type_info.core_type

Expand Down Expand Up @@ -196,9 +222,6 @@ def decode(value: Any) -> Any | None:
field_path.pop()
return decode

if src_type_kind == "Union":
return lambda value: value[1]

return lambda value: value


Expand Down
16 changes: 15 additions & 1 deletion python/cocoindex/tests/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,9 @@ def eq(a: Any, b: Any) -> bool:
)
decoder = make_engine_value_decoder([], encoded_output_type, value_type)
decoded_value = decoder(value_from_engine)
assert eq(decoded_value, value)
assert eq(decoded_value, value), (
f"{decoded_value} != {value}; {encoded_value}; {value_type}; {encoded_output_type}"
)

if other_decoded_values is not None:
for other_value, other_type in other_decoded_values:
Expand Down Expand Up @@ -613,6 +615,18 @@ def test_roundtrip_union_timedelta() -> None:
validate_full_roundtrip(value, t)


def test_roundtrip_vector_of_union() -> None:
    """Round-trip a list whose declared element type is the union `str | int`."""
    mixed_items = ["a", 1]
    validate_full_roundtrip(mixed_items, list[str | int])


def test_roundtrip_union_with_vector() -> None:
    """Round-trip a float32 array declared as `NDArray[np.float32] | str`.

    The extra (value, type) pair passed as the third argument asks the helper
    to also check decoding against `list[float] | str`.
    """
    union_type = NDArray[np.float32] | str
    vector = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    as_plain_list = ([1.0, 2.0, 3.0], list[float] | str)
    validate_full_roundtrip(vector, union_type, as_plain_list)


def test_roundtrip_ltable() -> None:
t = list[Order]
value = [Order("O1", "item1", 10.0), Order("O2", "item2", 20.0)]
Expand Down