Skip to content

Commit 1b02448

Browse files
committed
refactor: simplify DtypeRegistry structure and associated methods
1 parent 2f4ac25 commit 1b02448

File tree

3 files changed

+53
-70
lines changed

3 files changed

+53
-70
lines changed

python/cocoindex/convert.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,19 @@
66
import datetime
77
import inspect
88
import uuid
9+
from enum import Enum
10+
from typing import Any, Callable, Mapping, get_origin
11+
912
import numpy as np
1013

11-
from enum import Enum
12-
from typing import Any, Callable, get_origin, Mapping
1314
from .typing import (
15+
KEY_FIELD_NAME,
16+
TABLE_TYPES,
17+
DtypeRegistry,
1418
analyze_type_info,
1519
encode_enriched_type,
1620
extract_ndarray_scalar_dtype,
1721
is_namedtuple_type,
18-
TABLE_TYPES,
19-
KEY_FIELD_NAME,
20-
DtypeRegistry,
2122
)
2223

2324

@@ -178,12 +179,7 @@ def decode_vector(value: Any) -> Any | None:
178179
scalar_dtype = extract_ndarray_scalar_dtype(
179180
dst_type_info.np_number_type
180181
)
181-
dtype_info = DtypeRegistry.get_by_dtype(scalar_dtype)
182-
if dtype_info is None:
183-
raise ValueError(
184-
f"Unsupported dtype in NDArray: {scalar_dtype}. "
185-
f"Supported dtypes: {DtypeRegistry.supported_dtypes()}"
186-
)
182+
_ = DtypeRegistry.validate_and_get_dtype_info(scalar_dtype)
187183
return np.array(value, dtype=scalar_dtype)
188184

189185
return decode_vector

python/cocoindex/tests/test_typing.py

Lines changed: 15 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010

1111
from cocoindex.typing import (
1212
AnalyzedTypeInfo,
13-
Float32,
14-
Float64,
1513
TypeAttr,
1614
TypeKind,
1715
Vector,
@@ -37,7 +35,7 @@ def test_ndarray_float32_no_dim() -> None:
3735
result = analyze_type_info(typ)
3836
assert result.kind == "Vector"
3937
assert result.vector_info == VectorInfo(dim=None)
40-
assert result.elem_type == Float32
38+
assert result.elem_type == np.float32
4139
assert result.key_type is None
4240
assert result.struct_type is None
4341
assert result.nullable is False
@@ -51,7 +49,7 @@ def test_vector_float32_no_dim() -> None:
5149
result = analyze_type_info(typ)
5250
assert result.kind == "Vector"
5351
assert result.vector_info == VectorInfo(dim=None)
54-
assert result.elem_type == Float32
52+
assert result.elem_type == np.float32
5553
assert result.key_type is None
5654
assert result.struct_type is None
5755
assert result.nullable is False
@@ -65,7 +63,7 @@ def test_ndarray_float64_with_dim() -> None:
6563
result = analyze_type_info(typ)
6664
assert result.kind == "Vector"
6765
assert result.vector_info == VectorInfo(dim=128)
68-
assert result.elem_type == Float64
66+
assert result.elem_type == np.float64
6967
assert result.key_type is None
7068
assert result.struct_type is None
7169
assert result.nullable is False
@@ -79,7 +77,7 @@ def test_vector_float32_with_dim() -> None:
7977
result = analyze_type_info(typ)
8078
assert result.kind == "Vector"
8179
assert result.vector_info == VectorInfo(dim=384)
82-
assert result.elem_type == Float32
80+
assert result.elem_type == np.float32
8381
assert result.key_type is None
8482
assert result.struct_type is None
8583
assert result.nullable is False
@@ -93,7 +91,7 @@ def test_ndarray_int64_no_dim() -> None:
9391
result = analyze_type_info(typ)
9492
assert result.kind == "Vector"
9593
assert result.vector_info == VectorInfo(dim=None)
96-
assert get_args(result.elem_type) == (int, TypeKind("Int64"))
94+
assert result.elem_type == np.int64
9795
assert result.nullable is False
9896
assert result.np_number_type is not None
9997
assert get_origin(result.np_number_type) == np.ndarray
@@ -105,7 +103,7 @@ def test_nullable_ndarray() -> None:
105103
result = analyze_type_info(typ)
106104
assert result.kind == "Vector"
107105
assert result.vector_info == VectorInfo(dim=None)
108-
assert result.elem_type == Float32
106+
assert result.elem_type == np.float32
109107
assert result.key_type is None
110108
assert result.struct_type is None
111109
assert result.nullable is True
@@ -121,12 +119,12 @@ def test_scalar_numpy_types() -> None:
121119
(np.float64, "Float64"),
122120
]:
123121
type_info = analyze_type_info(np_type)
124-
assert type_info.kind == expected_kind, (
125-
f"Expected {expected_kind} for {np_type}, got {type_info.kind}"
126-
)
127-
assert type_info.np_number_type == np_type, (
128-
f"Expected {np_type}, got {type_info.np_number_type}"
129-
)
122+
assert (
123+
type_info.kind == expected_kind
124+
), f"Expected {expected_kind} for {np_type}, got {type_info.kind}"
125+
assert (
126+
type_info.np_number_type == np_type
127+
), f"Expected {np_type}, got {type_info.np_number_type}"
130128
assert type_info.elem_type is None
131129
assert type_info.vector_info is None
132130

@@ -502,9 +500,9 @@ def test_encode_scalar_numpy_types_schema() -> None:
502500
(np.float64, "Float64"),
503501
]:
504502
schema = encode_enriched_type(np_type)
505-
assert schema["type"]["kind"] == expected_kind, (
506-
f"Expected {expected_kind} for {np_type}, got {schema['type']['kind']}"
507-
)
503+
assert (
504+
schema["type"]["kind"] == expected_kind
505+
), f"Expected {expected_kind} for {np_type}, got {schema['type']['kind']}"
508506
assert not schema.get("nullable", False)
509507

510508

python/cocoindex/typing.py

Lines changed: 31 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
1-
import typing
21
import collections
32
import dataclasses
43
import datetime
5-
import types
64
import inspect
5+
import types
6+
import typing
77
import uuid
88
from typing import (
9+
TYPE_CHECKING,
910
Annotated,
10-
NamedTuple,
1111
Any,
12-
KeysView,
13-
TypeVar,
14-
TYPE_CHECKING,
15-
overload,
1612
Generic,
1713
Literal,
14+
NamedTuple,
1815
Protocol,
16+
TypeVar,
17+
overload,
1918
)
19+
2020
import numpy as np
2121
from numpy.typing import NDArray
2222

@@ -113,41 +113,40 @@ def _is_struct_type(t: ElementType | None) -> bool:
113113
)
114114

115115

116-
class DtypeInfo:
117-
"""Metadata for a NumPy dtype."""
118-
119-
def __init__(self, numpy_dtype: type, kind: str, python_type: type) -> None:
120-
self.numpy_dtype = numpy_dtype
121-
self.kind = kind
122-
self.python_type = python_type
123-
self.annotated_type = Annotated[python_type, TypeKind(kind)]
124-
125-
126116
class DtypeRegistry:
127117
"""
128118
Registry for NumPy dtypes used in CocoIndex.
129-
Provides mappings from NumPy dtypes to CocoIndex's type representation.
119+
Maps NumPy dtypes to their CocoIndex type kind.
130120
"""
131121

132-
_mappings: dict[type, DtypeInfo] = {
133-
np.float32: DtypeInfo(np.float32, "Float32", float),
134-
np.float64: DtypeInfo(np.float64, "Float64", float),
135-
np.int64: DtypeInfo(np.int64, "Int64", int),
122+
_DTYPE_TO_KIND: dict[type, str] = {
123+
np.float32: "Float32",
124+
np.float64: "Float64",
125+
np.int64: "Int64",
136126
}
137127

138128
@classmethod
139-
def get_by_dtype(cls, dtype: Any) -> DtypeInfo | None:
140-
"""Get DtypeInfo by NumPy dtype."""
129+
def get_by_dtype(cls, dtype: Any) -> tuple[type, str] | None:
130+
"""Get the NumPy dtype and its CocoIndex kind by dtype."""
141131
if dtype is Any:
142132
raise TypeError(
143133
"NDArray for Vector must use a concrete numpy dtype, got `Any`."
144134
)
145-
return cls._mappings.get(dtype)
135+
kind = cls._DTYPE_TO_KIND.get(dtype)
136+
return None if kind is None else (dtype, kind)
146137

147-
@staticmethod
148-
def supported_dtypes() -> KeysView[type]:
149-
"""Get a list of supported NumPy dtypes."""
150-
return DtypeRegistry._mappings.keys()
138+
@classmethod
139+
def validate_and_get_dtype_info(cls, dtype: Any) -> tuple[type, str]:
140+
"""
141+
Validate that the given dtype is supported.
142+
"""
143+
dtype_info = cls.get_by_dtype(dtype)
144+
if dtype_info is None:
145+
raise ValueError(
146+
f"Unsupported NumPy dtype in NDArray: {dtype}. "
147+
f"Supported dtypes: {cls._DTYPE_TO_KIND.keys()}"
148+
)
149+
return dtype_info
151150

152151

153152
@dataclasses.dataclass
@@ -228,11 +227,7 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
228227
elif kind != "Struct":
229228
raise ValueError(f"Unexpected type kind for struct: {kind}")
230229
elif is_numpy_number_type(t):
231-
if (dtype_info := DtypeRegistry.get_by_dtype(t)) is not None:
232-
kind = dtype_info.kind
233-
np_number_type = dtype_info.numpy_dtype
234-
else:
235-
raise ValueError(f"Unsupported NumPy dtype: {t}")
230+
np_number_type, kind = DtypeRegistry.validate_and_get_dtype_info(t)
236231
elif base_type is collections.abc.Sequence or base_type is list:
237232
args = typing.get_args(t)
238233
elem_type = args[0]
@@ -253,14 +248,8 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
253248
elif base_type is np.ndarray:
254249
kind = "Vector"
255250
np_number_type = t
256-
numpy_dtype = extract_ndarray_scalar_dtype(np_number_type)
257-
dtype_info = DtypeRegistry.get_by_dtype(numpy_dtype)
258-
if dtype_info is None:
259-
raise ValueError(
260-
f"Unsupported numpy dtype for NDArray: {numpy_dtype}. "
261-
f"Supported dtypes: {DtypeRegistry.supported_dtypes()}"
262-
)
263-
elem_type = dtype_info.annotated_type
251+
elem_type = extract_ndarray_scalar_dtype(np_number_type)
252+
_ = DtypeRegistry.validate_and_get_dtype_info(elem_type)
264253
vector_info = VectorInfo(dim=None) if vector_info is None else vector_info
265254

266255
elif base_type is collections.abc.Mapping or base_type is dict:

0 commit comments

Comments
 (0)