Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/docs/core/data_types.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,8 @@ Optionally, it can have a fixed dimension. Noted as *Vector[Type]* or *Vector[Ty
It supports the following Python types:

* `cocoindex.Vector[T]` or `cocoindex.Vector[T, typing.Literal[Dim]]`, e.g. `cocoindex.Vector[cocoindex.Float32]` or `cocoindex.Vector[cocoindex.Float32, typing.Literal[384]]`
* The underlying Python type is `numpy.typing.NDArray[T]` where `T` is a numpy numeric type (`numpy.int64`, `numpy.float32` or `numpy.float64`), or `list[T]` otherwise
* `numpy.typing.NDArray[T]` where `T` is a numpy numeric type
* The underlying Python type is `numpy.typing.NDArray[T]` where `T` is a numpy numeric type (`numpy.int64`, `numpy.float32` or `numpy.float64`) or array type (`numpy.typing.NDArray[T]`), or `list[T]` otherwise
* `numpy.typing.NDArray[T]` where `T` is a numpy numeric type or array type
* `list[T]`


Expand Down
9 changes: 3 additions & 6 deletions python/cocoindex/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,12 +209,9 @@ def decode(value: Any) -> Any | None:

vec_elem_decoder = None
scalar_dtype = None
if (
dst_type_variant
and is_numpy_number_type(dst_type_variant.elem_type)
and dst_type_info.base_type is np.ndarray
):
scalar_dtype = dst_type_variant.elem_type
if dst_type_variant and dst_type_info.base_type is np.ndarray:
if is_numpy_number_type(dst_type_variant.elem_type):
scalar_dtype = dst_type_variant.elem_type
else:
vec_elem_decoder = make_engine_value_decoder(
field_path + ["[*]"],
Expand Down
23 changes: 22 additions & 1 deletion python/cocoindex/tests/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,9 @@ def eq(a: Any, b: Any) -> bool:
for other_value, other_type in decoded_values:
decoder = make_engine_value_decoder([], encoded_output_type, other_type)
other_decoded_value = decoder(value_from_engine)
assert eq(other_decoded_value, other_value)
assert eq(other_decoded_value, other_value), (
f"Expected {other_value} but got {other_decoded_value} for {other_type}"
)


def validate_full_roundtrip(
Expand Down Expand Up @@ -1096,6 +1098,25 @@ def test_full_roundtrip_vector_numeric_types() -> None:
validate_full_roundtrip(value_u64, Vector[np.uint64, Literal[3]])


def test_full_roundtrip_vector_of_vector() -> None:
    """Round-trip a 2x3 float32 matrix annotated as a vector of fixed-size vectors.

    The primary annotation is ``Vector[Vector[np.float32, Literal[3]], Literal[2]]``.
    Additional ``(expected value, annotation)`` pairs are forwarded to
    ``validate_full_roundtrip`` - presumably checked as alternative decode
    targets; confirm against that helper's signature.
    """
    matrix_f32 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
    primary_type = Vector[Vector[np.float32, Literal[3]], Literal[2]]

    # Nested-list expectations are written out separately so that no two
    # pairs share the same list object.
    as_numpy_scalar_lists = (
        [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
        list[list[np.float32]],
    )
    as_float32_lists = (
        [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
        list[list[cocoindex.Float32]],
    )
    as_list_of_vectors = (
        [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
        list[Vector[cocoindex.Float32, Literal[3]]],
    )
    as_nested_ndarray = (
        matrix_f32,
        np.typing.NDArray[np.typing.NDArray[np.float32]],
    )

    validate_full_roundtrip(
        matrix_f32,
        primary_type,
        as_numpy_scalar_lists,
        as_float32_lists,
        as_list_of_vectors,
        as_nested_ndarray,
    )


def test_full_roundtrip_vector_other_types() -> None:
"""Test full roundtrip for Vector with non-numeric basic types."""
uuid_list = [uuid.uuid4(), uuid.uuid4()]
Expand Down
25 changes: 7 additions & 18 deletions python/cocoindex/tests/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,14 +124,6 @@ def test_vector_str() -> None:
assert result.variant.vector_info == VectorInfo(dim=None)


def test_vector_complex64() -> None:
typ = Vector[np.complex64]
result = analyze_type_info(typ)
assert isinstance(result.variant, AnalyzedListType)
assert result.variant.elem_type == np.complex64
assert result.variant.vector_info == VectorInfo(dim=None)


def test_non_numpy_vector() -> None:
typ = Vector[float, Literal[3]]
result = analyze_type_info(typ)
Expand All @@ -140,14 +132,6 @@ def test_non_numpy_vector() -> None:
assert result.variant.vector_info == VectorInfo(dim=3)


def test_ndarray_any_dtype() -> None:
typ = NDArray[Any]
with pytest.raises(
TypeError, match="NDArray for Vector must use a concrete numpy dtype"
):
analyze_type_info(typ)


def test_list_of_primitives() -> None:
typ = list[str]
result = analyze_type_info(typ)
Expand Down Expand Up @@ -439,9 +423,14 @@ def test_annotated_list_with_type_kind() -> None:


def test_unsupported_type() -> None:
    """Check that unsupported annotations are rejected with a ``ValueError``.

    Two entry points are exercised:

    * ``analyze_type_info(set)`` - a bare builtin that is not a supported
      CocoIndex data type.
    * ``Vector[np.complex64]`` - subscripting ``Vector`` with an unsupported
      element type must fail at subscription time.
    """
    # Exactly one raising call per ``pytest.raises`` block: anything placed
    # after the first raising statement inside the block would never run.
    with pytest.raises(
        ValueError,
        match="Unsupported as a specific type annotation for CocoIndex data type.*: <class 'set'>",
    ):
        analyze_type_info(set)

    with pytest.raises(
        ValueError,
        match="Unsupported as a specific type annotation for CocoIndex data type.*: <class 'numpy.complex64'>",
    ):
        Vector[np.complex64]
22 changes: 11 additions & 11 deletions python/cocoindex/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,7 @@ def __class_getitem__(self, params):
if not isinstance(params, tuple):
# No dimension provided, e.g., Vector[np.float32]
dtype = params
# Use NDArray for supported numeric dtypes, else list
if dtype in DtypeRegistry._DTYPE_TO_KIND:
return Annotated[NDArray[dtype], VectorInfo(dim=None)]
return Annotated[list[dtype], VectorInfo(dim=None)]
vector_info = VectorInfo(dim=None)
else:
# Element type and dimension provided, e.g., Vector[np.float32, Literal[3]]
dtype, dim_literal = params
Expand All @@ -80,16 +77,20 @@ def __class_getitem__(self, params):
if typing.get_origin(dim_literal) is Literal
else None
)
if dtype in DtypeRegistry._DTYPE_TO_KIND:
return Annotated[NDArray[dtype], VectorInfo(dim=dim_val)]
return Annotated[list[dtype], VectorInfo(dim=dim_val)]
vector_info = VectorInfo(dim=dim_val)

# Use NDArray for supported numeric dtypes, else list
base_type = analyze_type_info(dtype).base_type
if is_numpy_number_type(base_type) or base_type is np.ndarray:
return Annotated[NDArray[dtype], vector_info]
return Annotated[list[dtype], vector_info]


TABLE_TYPES: tuple[str, str] = ("KTable", "LTable")
KEY_FIELD_NAME: str = "_key"


def extract_ndarray_scalar_dtype(ndarray_type: Any) -> Any:
def extract_ndarray_elem_dtype(ndarray_type: Any) -> Any:
args = typing.get_args(ndarray_type)
_, dtype_spec = args
dtype_args = typing.get_args(dtype_spec)
Expand All @@ -99,7 +100,7 @@ def extract_ndarray_scalar_dtype(ndarray_type: Any) -> Any:


def is_numpy_number_type(t: type) -> bool:
    """Return True when *t* is a numpy integer or floating-point scalar class.

    Args:
        t: The object to test. Anything that is not a class (instances,
            generic aliases, ``None``) yields ``False`` instead of raising.

    Returns:
        ``True`` for subclasses of ``np.integer`` or ``np.floating``
        (e.g. ``np.int64``, ``np.uint8``, ``np.float32``); ``False`` for
        everything else, including complex dtypes such as ``np.complex64``
        and plain Python ``int`` / ``float``.
    """
    # ``np.number`` would also admit complex scalar types, so the check is
    # restricted to the integer and floating branches of the hierarchy.
    return isinstance(t, type) and issubclass(t, (np.integer, np.floating))


def is_namedtuple_type(t: type) -> bool:
Expand Down Expand Up @@ -273,8 +274,7 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
variant = AnalyzedListType(elem_type=elem_type, vector_info=vector_info)
elif base_type is np.ndarray:
np_number_type = t
elem_type = extract_ndarray_scalar_dtype(np_number_type)
_ = DtypeRegistry.validate_dtype_and_get_kind(elem_type)
elem_type = extract_ndarray_elem_dtype(np_number_type)
variant = AnalyzedListType(elem_type=elem_type, vector_info=vector_info)
elif base_type is collections.abc.Mapping or base_type is dict or t is dict:
key_type = type_args[0] if len(type_args) > 0 else None
Expand Down
Loading