diff --git a/docs/docs/ai/llm.mdx b/docs/docs/ai/llm.mdx index b3073d8e0..7966b0ab0 100644 --- a/docs/docs/ai/llm.mdx +++ b/docs/docs/ai/llm.mdx @@ -309,7 +309,7 @@ cocoindex.LlmSpec( You can find the full list of models supported by OpenRouter [here](https://openrouter.ai/models). -### vLLM +### vLLM Install vLLM: @@ -338,4 +338,4 @@ cocoindex.LlmSpec( ``` - \ No newline at end of file + diff --git a/python/cocoindex/convert.py b/python/cocoindex/convert.py index 648cf83f3..3d3dca8c1 100644 --- a/python/cocoindex/convert.py +++ b/python/cocoindex/convert.py @@ -5,7 +5,6 @@ import dataclasses import datetime import inspect -import uuid from enum import Enum from typing import Any, Callable, Mapping, get_origin @@ -14,7 +13,6 @@ from .typing import ( KEY_FIELD_NAME, TABLE_TYPES, - AnalyzedTypeInfo, DtypeRegistry, analyze_type_info, encode_enriched_type, @@ -74,23 +72,16 @@ def make_engine_value_decoder( Returns: A decoder from an engine value to a Python value. """ - src_type_kind = src_type["kind"] - dst_type_info: AnalyzedTypeInfo | None = None - if ( - dst_annotation is not None - and dst_annotation is not inspect.Parameter.empty - and dst_annotation is not Any - ): - dst_type_info = analyze_type_info(dst_annotation) - if not _is_type_kind_convertible_to(src_type_kind, dst_type_info.kind): - raise ValueError( - f"Type mismatch for `{''.join(field_path)}`: " - f"passed in {src_type_kind}, declared {dst_annotation} ({dst_type_info.kind})" - ) - - if dst_type_info is None: + dst_is_any = ( + dst_annotation is None + or dst_annotation is inspect.Parameter.empty + or dst_annotation is Any + ) + if dst_is_any: + if src_type_kind == "Union": + return lambda value: value[1] if src_type_kind == "Struct" or src_type_kind in TABLE_TYPES: raise ValueError( f"Missing type annotation for `{''.join(field_path)}`." @@ -98,6 +89,41 @@ def make_engine_value_decoder( ) return lambda value: value + dst_type_info = analyze_type_info(dst_annotation) + + if src_type_kind == "Union": + dst_type_variants = ( + dst_type_info.union_variant_types + if dst_type_info.union_variant_types is not None + else [dst_annotation] + ) + src_type_variants = src_type["types"] + decoders = [] + for i, src_type_variant in enumerate(src_type_variants): + src_field_path = field_path + [f"[{i}]"] + decoder = None + for dst_type_variant in dst_type_variants: + try: + decoder = make_engine_value_decoder( + src_field_path, src_type_variant, dst_type_variant + ) + break + except ValueError: + pass + if decoder is None: + raise ValueError( + f"Type mismatch for `{''.join(field_path)}`: " + f"cannot find matched target type for source type variant {src_type_variant}" + ) + decoders.append(decoder) + return lambda value: decoders[value[0]](value[1]) + + if not _is_type_kind_convertible_to(src_type_kind, dst_type_info.kind): + raise ValueError( + f"Type mismatch for `{''.join(field_path)}`: " + f"passed in {src_type_kind}, declared {dst_annotation} ({dst_type_info.kind})" + ) + if dst_type_info.kind in ("Float32", "Float64", "Int64"): dst_core_type = dst_type_info.core_type @@ -196,9 +222,6 @@ def decode(value: Any) -> Any | None: field_path.pop() return decode - if src_type_kind == "Union": - return lambda value: value[1] - return lambda value: value diff --git a/python/cocoindex/tests/test_convert.py b/python/cocoindex/tests/test_convert.py index acc19e779..ee9d200da 100644 --- a/python/cocoindex/tests/test_convert.py +++ b/python/cocoindex/tests/test_convert.py @@ -104,7 +104,9 @@ def eq(a: Any, b: Any) -> bool: ) decoder = make_engine_value_decoder([], encoded_output_type, value_type) decoded_value = decoder(value_from_engine) - assert eq(decoded_value, value) + assert eq(decoded_value, value), ( + f"{decoded_value} != {value}; {encoded_value}; {value_type}; {encoded_output_type}" + ) if other_decoded_values is not None: for other_value, other_type in other_decoded_values: @@ -613,6 +615,18 @@ def test_roundtrip_union_timedelta() -> None: validate_full_roundtrip(value, t) +def test_roundtrip_vector_of_union() -> None: + t = list[str | int] + value = ["a", 1] + validate_full_roundtrip(value, t) + + +def test_roundtrip_union_with_vector() -> None: + t = NDArray[np.float32] | str + value = np.array([1.0, 2.0, 3.0], dtype=np.float32) + validate_full_roundtrip(value, t, ([1.0, 2.0, 3.0], list[float] | str)) + + def test_roundtrip_ltable() -> None: t = list[Order] value = [Order("O1", "item1", 10.0), Order("O2", "item2", 20.0)]