Skip to content

Commit 98026b5

Browse files
committed
feat: support dtype decoding by adding np_number_type to AnalyzedTypeInfo
1 parent 76d1304 commit 98026b5

File tree

3 files changed

+40
-28
lines changed

3 files changed

+40
-28
lines changed

python/cocoindex/convert.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,8 +127,7 @@ def decode(value: Any) -> Any | None:
127127
return lambda value: uuid.UUID(bytes=value)
128128

129129
if src_type_kind == "Vector":
130-
elem_coco_type_info = analyze_type_info(dst_type_info.elem_type)
131-
dtype_info = DtypeRegistry.get_by_kind(elem_coco_type_info.kind)
130+
dtype_info = DtypeRegistry.get_by_dtype(dst_type_info.np_number_type)
132131

133132
def decode_vector(value: Any) -> Any | None:
134133
if value is None:

python/cocoindex/tests/test_typing.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ def test_ndarray_float32_no_dim():
4848
elem_type=Float32,
4949
key_type=None,
5050
struct_type=None,
51+
np_number_type=np.float32,
5152
attrs=None,
5253
nullable=False,
5354
)
@@ -62,6 +63,7 @@ def test_vector_float32_no_dim():
6263
elem_type=Float32,
6364
key_type=None,
6465
struct_type=None,
66+
np_number_type=np.float32,
6567
attrs=None,
6668
nullable=False,
6769
)
@@ -76,6 +78,7 @@ def test_ndarray_float64_with_dim():
7678
elem_type=Float64,
7779
key_type=None,
7880
struct_type=None,
81+
np_number_type=np.float64,
7982
attrs=None,
8083
nullable=False,
8184
)
@@ -90,6 +93,7 @@ def test_vector_float32_with_dim():
9093
elem_type=Float32,
9194
key_type=None,
9295
struct_type=None,
96+
np_number_type=np.float32,
9397
attrs=None,
9498
nullable=False,
9599
)
@@ -109,7 +113,7 @@ def test_ndarray_int32_with_dim():
109113
result = analyze_type_info(typ)
110114
assert result.kind == "Vector"
111115
assert result.vector_info == VectorInfo(dim=10)
112-
assert get_args(result.elem_type) == (int, TypeKind("Int32"))
116+
assert get_args(result.elem_type) == (int, TypeKind("Int64"))
113117
assert not result.nullable
114118

115119

@@ -118,7 +122,7 @@ def test_ndarray_uint8_no_dim():
118122
result = analyze_type_info(typ)
119123
assert result.kind == "Vector"
120124
assert result.vector_info == VectorInfo(dim=None)
121-
assert get_args(result.elem_type) == (int, TypeKind("UInt8"))
125+
assert get_args(result.elem_type) == (int, TypeKind("Int64"))
122126
assert not result.nullable
123127

124128

@@ -131,6 +135,7 @@ def test_nullable_ndarray():
131135
elem_type=Float32,
132136
key_type=None,
133137
struct_type=None,
138+
np_number_type=np.float32,
134139
attrs=None,
135140
nullable=True,
136141
)
@@ -177,6 +182,7 @@ def test_list_of_primitives():
177182
elem_type=str,
178183
key_type=None,
179184
struct_type=None,
185+
np_number_type=None,
180186
attrs=None,
181187
nullable=False,
182188
)
@@ -191,6 +197,7 @@ def test_list_of_structs():
191197
elem_type=SimpleDataclass,
192198
key_type=None,
193199
struct_type=None,
200+
np_number_type=None,
194201
attrs=None,
195202
nullable=False,
196203
)
@@ -205,6 +212,7 @@ def test_sequence_of_int():
205212
elem_type=int,
206213
key_type=None,
207214
struct_type=None,
215+
np_number_type=None,
208216
attrs=None,
209217
nullable=False,
210218
)
@@ -219,6 +227,7 @@ def test_list_with_vector_info():
219227
elem_type=int,
220228
key_type=None,
221229
struct_type=None,
230+
np_number_type=None,
222231
attrs=None,
223232
nullable=False,
224233
)
@@ -233,6 +242,7 @@ def test_dict_str_int():
233242
elem_type=(str, int),
234243
key_type=None,
235244
struct_type=None,
245+
np_number_type=None,
236246
attrs=None,
237247
nullable=False,
238248
)
@@ -247,6 +257,7 @@ def test_mapping_str_dataclass():
247257
elem_type=(str, SimpleDataclass),
248258
key_type=None,
249259
struct_type=None,
260+
np_number_type=None,
250261
attrs=None,
251262
nullable=False,
252263
)
@@ -261,6 +272,7 @@ def test_dataclass():
261272
elem_type=None,
262273
key_type=None,
263274
struct_type=SimpleDataclass,
275+
np_number_type=None,
264276
attrs=None,
265277
nullable=False,
266278
)
@@ -275,6 +287,7 @@ def test_named_tuple():
275287
elem_type=None,
276288
key_type=None,
277289
struct_type=SimpleNamedTuple,
290+
np_number_type=None,
278291
attrs=None,
279292
nullable=False,
280293
)
@@ -289,6 +302,7 @@ def test_tuple_key_value():
289302
elem_type=None,
290303
key_type=str,
291304
struct_type=None,
305+
np_number_type=None,
292306
attrs=None,
293307
nullable=False,
294308
)
@@ -303,6 +317,7 @@ def test_str():
303317
elem_type=None,
304318
key_type=None,
305319
struct_type=None,
320+
np_number_type=None,
306321
attrs=None,
307322
nullable=False,
308323
)
@@ -317,6 +332,7 @@ def test_bool():
317332
elem_type=None,
318333
key_type=None,
319334
struct_type=None,
335+
np_number_type=None,
320336
attrs=None,
321337
nullable=False,
322338
)
@@ -331,6 +347,7 @@ def test_bytes():
331347
elem_type=None,
332348
key_type=None,
333349
struct_type=None,
350+
np_number_type=None,
334351
attrs=None,
335352
nullable=False,
336353
)
@@ -345,6 +362,7 @@ def test_uuid():
345362
elem_type=None,
346363
key_type=None,
347364
struct_type=None,
365+
np_number_type=None,
348366
attrs=None,
349367
nullable=False,
350368
)
@@ -359,6 +377,7 @@ def test_date():
359377
elem_type=None,
360378
key_type=None,
361379
struct_type=None,
380+
np_number_type=None,
362381
attrs=None,
363382
nullable=False,
364383
)
@@ -373,6 +392,7 @@ def test_time():
373392
elem_type=None,
374393
key_type=None,
375394
struct_type=None,
395+
np_number_type=None,
376396
attrs=None,
377397
nullable=False,
378398
)
@@ -387,6 +407,7 @@ def test_timedelta():
387407
elem_type=None,
388408
key_type=None,
389409
struct_type=None,
410+
np_number_type=None,
390411
attrs=None,
391412
nullable=False,
392413
)
@@ -401,6 +422,7 @@ def test_float():
401422
elem_type=None,
402423
key_type=None,
403424
struct_type=None,
425+
np_number_type=None,
404426
attrs=None,
405427
nullable=False,
406428
)
@@ -415,6 +437,7 @@ def test_int():
415437
elem_type=None,
416438
key_type=None,
417439
struct_type=None,
440+
np_number_type=None,
418441
attrs=None,
419442
nullable=False,
420443
)
@@ -429,6 +452,7 @@ def test_type_with_attributes():
429452
elem_type=None,
430453
key_type=None,
431454
struct_type=None,
455+
np_number_type=None,
432456
attrs={"key": "value"},
433457
nullable=False,
434458
)

python/cocoindex/typing.py

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -119,11 +119,11 @@ class DtypeRegistry:
119119
_mappings: dict[type, DtypeInfo] = {
120120
np.float32: DtypeInfo(np.float32, "Float32", float),
121121
np.float64: DtypeInfo(np.float64, "Float64", float),
122-
np.int32: DtypeInfo(np.int32, "Int32", int),
122+
np.int32: DtypeInfo(np.int32, "Int64", int),
123123
np.int64: DtypeInfo(np.int64, "Int64", int),
124-
np.uint8: DtypeInfo(np.uint8, "UInt8", int),
125-
np.uint16: DtypeInfo(np.uint16, "UInt16", int),
126-
np.uint32: DtypeInfo(np.uint32, "UInt32", int),
124+
np.uint8: DtypeInfo(np.uint8, "Int64", int),
125+
np.uint16: DtypeInfo(np.uint16, "Int64", int),
126+
np.uint32: DtypeInfo(np.uint32, "Int64", int),
127127
}
128128

129129
@classmethod
@@ -135,20 +135,6 @@ def get_by_dtype(cls, dtype: Any) -> DtypeInfo | None:
135135
)
136136
return cls._mappings.get(dtype)
137137

138-
@staticmethod
139-
def get_by_kind(kind: str) -> DtypeInfo | None:
140-
"""Get DtypeInfo by kind."""
141-
return next(
142-
(info for info in DtypeRegistry._mappings.values() if info.kind == kind),
143-
None,
144-
)
145-
146-
@staticmethod
147-
def rust_compatible_kind(kind: str) -> str:
148-
"""Map to a Rust-compatible kind for schema encoding."""
149-
# incompatible_integer_kinds = {"Int32", "UInt8", "UInt16", "UInt32", "UInt64"}
150-
return "Int64" if "Int" in kind else kind
151-
152138
@staticmethod
153139
def supported_dtypes() -> KeysView[type]:
154140
"""Get a list of supported NumPy dtypes."""
@@ -167,6 +153,9 @@ class AnalyzedTypeInfo:
167153

168154
key_type: type | None # For element of KTable
169155
struct_type: type | None # For Struct, a dataclass or namedtuple
156+
np_number_type: (
157+
type | None
158+
) # NumPy dtype for the element type, if represented by numpy.ndarray or a NumPy scalar
170159

171160
attrs: dict[str, Any] | None
172161
nullable: bool = False
@@ -221,6 +210,7 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
221210
struct_type: type | None = None
222211
elem_type: ElementType | None = None
223212
key_type: type | None = None
213+
np_number_type: type | None = None
224214
if _is_struct_type(t):
225215
struct_type = t
226216

@@ -254,11 +244,11 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
254244
if not dtype_args:
255245
raise ValueError("Invalid dtype specification for NDArray")
256246

257-
numpy_dtype = dtype_args[0]
258-
dtype_info = DtypeRegistry.get_by_dtype(numpy_dtype)
247+
np_number_type = dtype_args[0]
248+
dtype_info = DtypeRegistry.get_by_dtype(np_number_type)
259249
if dtype_info is None:
260250
raise ValueError(
261-
f"Unsupported numpy dtype for NDArray: {numpy_dtype}. "
251+
f"Unsupported numpy dtype for NDArray: {np_number_type}. "
262252
f"Supported dtypes: {DtypeRegistry.supported_dtypes()}"
263253
)
264254
elem_type = dtype_info.annotated_type
@@ -272,6 +262,7 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
272262
dtype_info = DtypeRegistry.get_by_dtype(t)
273263
if dtype_info is not None:
274264
kind = dtype_info.kind
265+
np_number_type = dtype_info.numpy_dtype
275266
elif t is bytes:
276267
kind = "Bytes"
277268
elif t is str:
@@ -301,6 +292,7 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
301292
elem_type=elem_type,
302293
key_type=key_type,
303294
struct_type=struct_type,
295+
np_number_type=np_number_type,
304296
attrs=attrs,
305297
nullable=nullable,
306298
)
@@ -355,9 +347,6 @@ def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]:
355347
raise ValueError("Vector type must have an element type")
356348
elem_type_info = analyze_type_info(type_info.elem_type)
357349
encoded_type["element_type"] = _encode_type(elem_type_info)
358-
encoded_type["element_type"]["kind"] = DtypeRegistry.rust_compatible_kind(
359-
elem_type_info.kind
360-
)
361350
encoded_type["dimension"] = type_info.vector_info.dim
362351

363352
elif type_info.kind in TABLE_TYPES:

0 commit comments

Comments
 (0)