Skip to content

Commit 9ea8ca7

Browse files
authored
feat(md-vector): support multi-dimensional vectors (#801)
1 parent 70f3b71 commit 9ea8ca7

File tree

5 files changed

+45
-38
lines changed

5 files changed

+45
-38
lines changed

docs/docs/core/data_types.mdx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,8 @@ Optionally, it can have a fixed dimension. Noted as *Vector[Type]* or *Vector[Ty
8888
It supports the following Python types:
8989

9090
* `cocoindex.Vector[T]` or `cocoindex.Vector[T, typing.Literal[Dim]]`, e.g. `cocoindex.Vector[cocoindex.Float32]` or `cocoindex.Vector[cocoindex.Float32, typing.Literal[384]]`
91-
* The underlying Python type is `numpy.typing.NDArray[T]` where `T` is a numpy numeric type (`numpy.int64`, `numpy.float32` or `numpy.float64`), or `list[T]` otherwise
92-
* `numpy.typing.NDArray[T]` where `T` is a numpy numeric type
91+
* The underlying Python type is `numpy.typing.NDArray[T]` when `T` is a numpy numeric type (`numpy.int64`, `numpy.float32` or `numpy.float64`) or is itself an array type (`numpy.typing.NDArray[...]`, i.e. a multi-dimensional vector), and `list[T]` otherwise
92+
* `numpy.typing.NDArray[T]` where `T` is a numpy numeric type or array type
9393
* `list[T]`
9494

9595

python/cocoindex/convert.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -209,12 +209,9 @@ def decode(value: Any) -> Any | None:
209209

210210
vec_elem_decoder = None
211211
scalar_dtype = None
212-
if (
213-
dst_type_variant
214-
and is_numpy_number_type(dst_type_variant.elem_type)
215-
and dst_type_info.base_type is np.ndarray
216-
):
217-
scalar_dtype = dst_type_variant.elem_type
212+
if dst_type_variant and dst_type_info.base_type is np.ndarray:
213+
if is_numpy_number_type(dst_type_variant.elem_type):
214+
scalar_dtype = dst_type_variant.elem_type
218215
else:
219216
vec_elem_decoder = make_engine_value_decoder(
220217
field_path + ["[*]"],

python/cocoindex/tests/test_convert.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,9 @@ def eq(a: Any, b: Any) -> bool:
105105
for other_value, other_type in decoded_values:
106106
decoder = make_engine_value_decoder([], encoded_output_type, other_type)
107107
other_decoded_value = decoder(value_from_engine)
108-
assert eq(other_decoded_value, other_value)
108+
assert eq(other_decoded_value, other_value), (
109+
f"Expected {other_value} but got {other_decoded_value} for {other_type}"
110+
)
109111

110112

111113
def validate_full_roundtrip(
@@ -1096,6 +1098,25 @@ def test_full_roundtrip_vector_numeric_types() -> None:
10961098
validate_full_roundtrip(value_u64, Vector[np.uint64, Literal[3]])
10971099

10981100

1101+
def test_full_roundtrip_vector_of_vector() -> None:
1102+
"""Test full roundtrip for vector of vector."""
1103+
value_f32 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)
1104+
validate_full_roundtrip(
1105+
value_f32,
1106+
Vector[Vector[np.float32, Literal[3]], Literal[2]],
1107+
([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], list[list[np.float32]]),
1108+
([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], list[list[cocoindex.Float32]]),
1109+
(
1110+
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
1111+
list[Vector[cocoindex.Float32, Literal[3]]],
1112+
),
1113+
(
1114+
value_f32,
1115+
np.typing.NDArray[np.typing.NDArray[np.float32]],
1116+
),
1117+
)
1118+
1119+
10991120
def test_full_roundtrip_vector_other_types() -> None:
11001121
"""Test full roundtrip for Vector with non-numeric basic types."""
11011122
uuid_list = [uuid.uuid4(), uuid.uuid4()]

python/cocoindex/tests/test_typing.py

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -124,14 +124,6 @@ def test_vector_str() -> None:
124124
assert result.variant.vector_info == VectorInfo(dim=None)
125125

126126

127-
def test_vector_complex64() -> None:
128-
typ = Vector[np.complex64]
129-
result = analyze_type_info(typ)
130-
assert isinstance(result.variant, AnalyzedListType)
131-
assert result.variant.elem_type == np.complex64
132-
assert result.variant.vector_info == VectorInfo(dim=None)
133-
134-
135127
def test_non_numpy_vector() -> None:
136128
typ = Vector[float, Literal[3]]
137129
result = analyze_type_info(typ)
@@ -140,14 +132,6 @@ def test_non_numpy_vector() -> None:
140132
assert result.variant.vector_info == VectorInfo(dim=3)
141133

142134

143-
def test_ndarray_any_dtype() -> None:
144-
typ = NDArray[Any]
145-
with pytest.raises(
146-
TypeError, match="NDArray for Vector must use a concrete numpy dtype"
147-
):
148-
analyze_type_info(typ)
149-
150-
151135
def test_list_of_primitives() -> None:
152136
typ = list[str]
153137
result = analyze_type_info(typ)
@@ -439,9 +423,14 @@ def test_annotated_list_with_type_kind() -> None:
439423

440424

441425
def test_unsupported_type() -> None:
442-
typ = set
443426
with pytest.raises(
444427
ValueError,
445428
match="Unsupported as a specific type annotation for CocoIndex data type.*: <class 'set'>",
446429
):
447-
analyze_type_info(typ)
430+
analyze_type_info(set)
431+
432+
with pytest.raises(
433+
ValueError,
434+
match="Unsupported as a specific type annotation for CocoIndex data type.*: <class 'numpy.complex64'>",
435+
):
436+
Vector[np.complex64]

python/cocoindex/typing.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,7 @@ def __class_getitem__(self, params):
6767
if not isinstance(params, tuple):
6868
# No dimension provided, e.g., Vector[np.float32]
6969
dtype = params
70-
# Use NDArray for supported numeric dtypes, else list
71-
if dtype in DtypeRegistry._DTYPE_TO_KIND:
72-
return Annotated[NDArray[dtype], VectorInfo(dim=None)]
73-
return Annotated[list[dtype], VectorInfo(dim=None)]
70+
vector_info = VectorInfo(dim=None)
7471
else:
7572
# Element type and dimension provided, e.g., Vector[np.float32, Literal[3]]
7673
dtype, dim_literal = params
@@ -80,16 +77,20 @@ def __class_getitem__(self, params):
8077
if typing.get_origin(dim_literal) is Literal
8178
else None
8279
)
83-
if dtype in DtypeRegistry._DTYPE_TO_KIND:
84-
return Annotated[NDArray[dtype], VectorInfo(dim=dim_val)]
85-
return Annotated[list[dtype], VectorInfo(dim=dim_val)]
80+
vector_info = VectorInfo(dim=dim_val)
81+
82+
# Use NDArray for numpy numeric dtypes and for nested ndarray element types (multi-dimensional vectors), else list
83+
base_type = analyze_type_info(dtype).base_type
84+
if is_numpy_number_type(base_type) or base_type is np.ndarray:
85+
return Annotated[NDArray[dtype], vector_info]
86+
return Annotated[list[dtype], vector_info]
8687

8788

8889
TABLE_TYPES: tuple[str, str] = ("KTable", "LTable")
8990
KEY_FIELD_NAME: str = "_key"
9091

9192

92-
def extract_ndarray_scalar_dtype(ndarray_type: Any) -> Any:
93+
def extract_ndarray_elem_dtype(ndarray_type: Any) -> Any:
9394
args = typing.get_args(ndarray_type)
9495
_, dtype_spec = args
9596
dtype_args = typing.get_args(dtype_spec)
@@ -99,7 +100,7 @@ def extract_ndarray_scalar_dtype(ndarray_type: Any) -> Any:
99100

100101

101102
def is_numpy_number_type(t: type) -> bool:
102-
return isinstance(t, type) and issubclass(t, np.number)
103+
return isinstance(t, type) and issubclass(t, (np.integer, np.floating))
103104

104105

105106
def is_namedtuple_type(t: type) -> bool:
@@ -273,8 +274,7 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
273274
variant = AnalyzedListType(elem_type=elem_type, vector_info=vector_info)
274275
elif base_type is np.ndarray:
275276
np_number_type = t
276-
elem_type = extract_ndarray_scalar_dtype(np_number_type)
277-
_ = DtypeRegistry.validate_dtype_and_get_kind(elem_type)
277+
elem_type = extract_ndarray_elem_dtype(np_number_type)
278278
variant = AnalyzedListType(elem_type=elem_type, vector_info=vector_info)
279279
elif base_type is collections.abc.Mapping or base_type is dict or t is dict:
280280
key_type = type_args[0] if len(type_args) > 0 else None

0 commit comments

Comments
 (0)