Skip to content

Commit 4848c50

Browse files
committed
feat: remove DtypeRegistry getter method
1 parent 48bbbaf commit 4848c50

File tree

3 files changed

+13
-24
lines changed

3 files changed

+13
-24
lines changed

python/cocoindex/convert.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ def decode_vector(value: Any) -> Any | None:
179179
scalar_dtype = extract_ndarray_scalar_dtype(
180180
dst_type_info.np_number_type
181181
)
182-
_ = DtypeRegistry.validate_and_get_dtype_info(scalar_dtype)
182+
_ = DtypeRegistry.validate_dtype_and_get_kind(scalar_dtype)
183183
return np.array(value, dtype=scalar_dtype)
184184

185185
return decode_vector

python/cocoindex/tests/test_convert.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -622,12 +622,6 @@ def test_vector_as_list() -> None:
622622
Float32VectorType = Vector[np.float32, Literal[3]]
623623
Float64VectorType = Vector[np.float64, Literal[3]]
624624
Int64VectorType = Vector[np.int64, Literal[3]]
625-
Int32VectorType = Vector[np.int32, Literal[3]]
626-
UInt8VectorType = Vector[np.uint8, Literal[3]]
627-
UInt16VectorType = Vector[np.uint16, Literal[3]]
628-
UInt32VectorType = Vector[np.uint32, Literal[3]]
629-
UInt64VectorType = Vector[np.uint64, Literal[3]]
630-
StrVectorType = Vector[str]
631625
NDArrayFloat32Type = NDArray[np.float32]
632626
NDArrayFloat64Type = NDArray[np.float64]
633627
NDArrayInt64Type = NDArray[np.int64]

python/cocoindex/typing.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ def __class_getitem__(self, params):
6767
# No dimension provided, e.g., Vector[np.float32]
6868
dtype = params
6969
# Use NDArray for supported numeric dtypes, else list
70-
if DtypeRegistry.get_by_dtype(dtype) is not None:
70+
if dtype in DtypeRegistry._DTYPE_TO_KIND:
7171
return Annotated[NDArray[dtype], VectorInfo(dim=None)]
7272
return Annotated[list[dtype], VectorInfo(dim=None)]
7373
else:
@@ -79,7 +79,7 @@ def __class_getitem__(self, params):
7979
if typing.get_origin(dim_literal) is Literal
8080
else None
8181
)
82-
if DtypeRegistry.get_by_dtype(dtype) is not None:
82+
if dtype in DtypeRegistry._DTYPE_TO_KIND:
8383
return Annotated[NDArray[dtype], VectorInfo(dim=dim_val)]
8484
return Annotated[list[dtype], VectorInfo(dim=dim_val)]
8585

@@ -119,34 +119,28 @@ class DtypeRegistry:
119119
Maps NumPy dtypes to their CocoIndex type kind.
120120
"""
121121

122-
_DTYPE_TO_KIND: dict[type, str] = {
122+
_DTYPE_TO_KIND: dict[ElementType, str] = {
123123
np.float32: "Float32",
124124
np.float64: "Float64",
125125
np.int64: "Int64",
126126
}
127127

128128
@classmethod
129-
def get_by_dtype(cls, dtype: Any) -> tuple[type, str] | None:
130-
"""Get the NumPy dtype and its CocoIndex kind by dtype."""
129+
def validate_dtype_and_get_kind(cls, dtype: ElementType) -> str:
130+
"""
131+
Validate that the given dtype is supported, and get its CocoIndex kind by dtype.
132+
"""
131133
if dtype is Any:
132134
raise TypeError(
133135
"NDArray for Vector must use a concrete numpy dtype, got `Any`."
134136
)
135137
kind = cls._DTYPE_TO_KIND.get(dtype)
136-
return None if kind is None else (dtype, kind)
137-
138-
@classmethod
139-
def validate_and_get_dtype_info(cls, dtype: Any) -> tuple[type, str]:
140-
"""
141-
Validate that the given dtype is supported.
142-
"""
143-
dtype_info = cls.get_by_dtype(dtype)
144-
if dtype_info is None:
138+
if kind is None:
145139
raise ValueError(
146140
f"Unsupported NumPy dtype in NDArray: {dtype}. "
147141
f"Supported dtypes: {cls._DTYPE_TO_KIND.keys()}"
148142
)
149-
return dtype_info
143+
return kind
150144

151145

152146
@dataclasses.dataclass
@@ -227,7 +221,8 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
227221
elif kind != "Struct":
228222
raise ValueError(f"Unexpected type kind for struct: {kind}")
229223
elif is_numpy_number_type(t):
230-
np_number_type, kind = DtypeRegistry.validate_and_get_dtype_info(t)
224+
np_number_type = t
225+
kind = DtypeRegistry.validate_dtype_and_get_kind(t)
231226
elif base_type is collections.abc.Sequence or base_type is list:
232227
args = typing.get_args(t)
233228
elem_type = args[0]
@@ -249,7 +244,7 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
249244
kind = "Vector"
250245
np_number_type = t
251246
elem_type = extract_ndarray_scalar_dtype(np_number_type)
252-
_ = DtypeRegistry.validate_and_get_dtype_info(elem_type)
247+
_ = DtypeRegistry.validate_dtype_and_get_kind(elem_type)
253248
vector_info = VectorInfo(dim=None) if vector_info is None else vector_info
254249

255250
elif base_type is collections.abc.Mapping or base_type is dict:

0 commit comments

Comments
 (0)