diff --git a/docs/docs/core/data_types.mdx b/docs/docs/core/data_types.mdx index 88eec7b1b..f5183e7ff 100644 --- a/docs/docs/core/data_types.mdx +++ b/docs/docs/core/data_types.mdx @@ -88,8 +88,8 @@ Optionally, it can have a fixed dimension. Noted as *Vector[Type]* or *Vector[Ty It supports the following Python types: * `cocoindex.Vector[T]` or `cocoindex.Vector[T, typing.Literal[Dim]]`, e.g. `cocoindex.Vector[cocoindex.Float32]` or `cocoindex.Vector[cocoindex.Float32, typing.Literal[384]]` - * The underlying Python type is `numpy.typing.NDArray[T]` where `T` is a numpy numeric type (`numpy.int64`, `numpy.float32` or `numpy.float64`), or `list[T]` otherwise -* `numpy.typing.NDArray[T]` where `T` is a numpy numeric type + * The underlying Python type is `numpy.typing.NDArray[T]` where `T` is a numpy numeric type (`numpy.int64`, `numpy.float32` or `numpy.float64`) or array type (`numpy.typing.NDArray[T]`), or `list[T]` otherwise +* `numpy.typing.NDArray[T]` where `T` is a numpy numeric type or array type * `list[T]` diff --git a/python/cocoindex/convert.py b/python/cocoindex/convert.py index e36b59ebc..a82103236 100644 --- a/python/cocoindex/convert.py +++ b/python/cocoindex/convert.py @@ -209,12 +209,9 @@ def decode(value: Any) -> Any | None: vec_elem_decoder = None scalar_dtype = None - if ( - dst_type_variant - and is_numpy_number_type(dst_type_variant.elem_type) - and dst_type_info.base_type is np.ndarray - ): - scalar_dtype = dst_type_variant.elem_type + if dst_type_variant and dst_type_info.base_type is np.ndarray: + if is_numpy_number_type(dst_type_variant.elem_type): + scalar_dtype = dst_type_variant.elem_type else: vec_elem_decoder = make_engine_value_decoder( field_path + ["[*]"], diff --git a/python/cocoindex/tests/test_convert.py b/python/cocoindex/tests/test_convert.py index 8d913f97b..717109819 100644 --- a/python/cocoindex/tests/test_convert.py +++ b/python/cocoindex/tests/test_convert.py @@ -105,7 +105,9 @@ def eq(a: Any, b: Any) -> bool: for other_value, other_type in decoded_values: decoder = make_engine_value_decoder([], encoded_output_type, other_type) other_decoded_value = decoder(value_from_engine) - assert eq(other_decoded_value, other_value) + assert eq(other_decoded_value, other_value), ( + f"Expected {other_value} but got {other_decoded_value} for {other_type}" + ) def validate_full_roundtrip( @@ -1096,6 +1098,25 @@ def test_full_roundtrip_vector_numeric_types() -> None: validate_full_roundtrip(value_u64, Vector[np.uint64, Literal[3]]) +def test_full_roundtrip_vector_of_vector() -> None: + """Test full roundtrip for vector of vector.""" + value_f32 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32) + validate_full_roundtrip( + value_f32, + Vector[Vector[np.float32, Literal[3]], Literal[2]], + ([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], list[list[np.float32]]), + ([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], list[list[cocoindex.Float32]]), + ( + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], + list[Vector[cocoindex.Float32, Literal[3]]], + ), + ( + value_f32, + np.typing.NDArray[np.typing.NDArray[np.float32]], + ), + ) + + def test_full_roundtrip_vector_other_types() -> None: """Test full roundtrip for Vector with non-numeric basic types.""" uuid_list = [uuid.uuid4(), uuid.uuid4()] diff --git a/python/cocoindex/tests/test_typing.py b/python/cocoindex/tests/test_typing.py index 8f64071a5..a528a45ce 100644 --- a/python/cocoindex/tests/test_typing.py +++ b/python/cocoindex/tests/test_typing.py @@ -124,14 +124,6 @@ def test_vector_str() -> None: assert result.variant.vector_info == VectorInfo(dim=None) -def test_vector_complex64() -> None: - typ = Vector[np.complex64] - result = analyze_type_info(typ) - assert isinstance(result.variant, AnalyzedListType) - assert result.variant.elem_type == np.complex64 - assert result.variant.vector_info == VectorInfo(dim=None) - - def test_non_numpy_vector() -> None: typ = Vector[float, Literal[3]] result = analyze_type_info(typ) @@ -140,14 +132,6 @@ def test_non_numpy_vector() -> None: assert result.variant.vector_info == VectorInfo(dim=3) -def test_ndarray_any_dtype() -> None: - typ = NDArray[Any] - with pytest.raises( - TypeError, match="NDArray for Vector must use a concrete numpy dtype" - ): - analyze_type_info(typ) - - def test_list_of_primitives() -> None: typ = list[str] result = analyze_type_info(typ) @@ -439,9 +423,14 @@ def test_annotated_list_with_type_kind() -> None: def test_unsupported_type() -> None: - typ = set with pytest.raises( ValueError, match="Unsupported as a specific type annotation for CocoIndex data type.*: ", ): - analyze_type_info(typ) + analyze_type_info(set) + + with pytest.raises( + ValueError, + match="Unsupported as a specific type annotation for CocoIndex data type.*: ", + ): + Vector[np.complex64] diff --git a/python/cocoindex/typing.py b/python/cocoindex/typing.py index cb0ae8872..041b782e1 100644 --- a/python/cocoindex/typing.py +++ b/python/cocoindex/typing.py @@ -67,10 +67,7 @@ def __class_getitem__(self, params): if not isinstance(params, tuple): # No dimension provided, e.g., Vector[np.float32] dtype = params - # Use NDArray for supported numeric dtypes, else list - if dtype in DtypeRegistry._DTYPE_TO_KIND: - return Annotated[NDArray[dtype], VectorInfo(dim=None)] - return Annotated[list[dtype], VectorInfo(dim=None)] + vector_info = VectorInfo(dim=None) else: # Element type and dimension provided, e.g., Vector[np.float32, Literal[3]] dtype, dim_literal = params @@ -80,16 +77,20 @@ def __class_getitem__(self, params): if typing.get_origin(dim_literal) is Literal else None ) - if dtype in DtypeRegistry._DTYPE_TO_KIND: - return Annotated[NDArray[dtype], VectorInfo(dim=dim_val)] - return Annotated[list[dtype], VectorInfo(dim=dim_val)] + vector_info = VectorInfo(dim=dim_val) + + # Use NDArray for supported numeric dtypes, else list + base_type = analyze_type_info(dtype).base_type + if is_numpy_number_type(base_type) or base_type is np.ndarray: + return Annotated[NDArray[dtype], vector_info] + return Annotated[list[dtype], vector_info] TABLE_TYPES: tuple[str, str] = ("KTable", "LTable") KEY_FIELD_NAME: str = "_key" -def extract_ndarray_scalar_dtype(ndarray_type: Any) -> Any: +def extract_ndarray_elem_dtype(ndarray_type: Any) -> Any: args = typing.get_args(ndarray_type) _, dtype_spec = args dtype_args = typing.get_args(dtype_spec) @@ -99,7 +100,7 @@ def extract_ndarray_scalar_dtype(ndarray_type: Any) -> Any: def is_numpy_number_type(t: type) -> bool: - return isinstance(t, type) and issubclass(t, np.number) + return isinstance(t, type) and issubclass(t, (np.integer, np.floating)) def is_namedtuple_type(t: type) -> bool: @@ -273,8 +274,7 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo: variant = AnalyzedListType(elem_type=elem_type, vector_info=vector_info) elif base_type is np.ndarray: np_number_type = t - elem_type = extract_ndarray_scalar_dtype(np_number_type) - _ = DtypeRegistry.validate_dtype_and_get_kind(elem_type) + elem_type = extract_ndarray_elem_dtype(np_number_type) variant = AnalyzedListType(elem_type=elem_type, vector_info=vector_info) elif base_type is collections.abc.Mapping or base_type is dict or t is dict: key_type = type_args[0] if len(type_args) > 0 else None