Skip to content

Commit abd3ad2

Browse files
Add support for typed numpy arrays (#1953)
Add support for typed numpy arrays like below: ``` NDArrayInt = npt.NDArray[np.int_] # numpy array of integers class DetectionSample(Sample): bboxes: NDArrayInt = bbox_field(dtype=pl.Int32) ``` Resolves #1949 <!-- Contributing guide: https://github.com/open-edge-platform/datumaro/blob/develop/contributing.md --> <!-- Please add a summary of changes. You may use Copilot to auto-generate the PR description but please consider including any other relevant facts which Copilot may be unaware of (such as design choices and testing procedure). Add references to the relevant issues and pull requests if any like so: Resolves #111 and #222. Depends on #1000 (for series of dependent commits). --> ### Checklist <!-- Put an 'x' in all the boxes that apply --> - [ ] I have added tests to cover my changes or documented any manual tests. - [ ] I have updated the [documentation](https://github.com/open-edge-platform/datumaro/tree/develop/docs) accordingly --------- Signed-off-by: Albert van Houten <[email protected]> Co-authored-by: Copilot <[email protected]>
1 parent 80d4e9f commit abd3ad2

File tree

2 files changed

+254
-25
lines changed

2 files changed

+254
-25
lines changed

src/datumaro/experimental/type_registry.py

Lines changed: 70 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import logging
1313
import types
1414
from collections.abc import Callable
15-
from typing import Any, Union
15+
from typing import Any, Union, get_args, get_origin
1616

1717
import numpy as np
1818
import polars as pl
@@ -165,6 +165,44 @@ def to_numpy(value: Any, dtype: Any = None) -> np.ndarray[Any, Any] | None:
165165
raise TypeError(f"No converter registered for type {value_type}")
166166

167167

168+
def _apply_numpy_dtype_from_type_annotation(array: np.ndarray, target_type: type) -> np.ndarray:
169+
"""Apply dtype conversion to numpy array based on type annotation.
170+
171+
Args:
172+
array: Numpy array to convert
173+
target_type: Type annotation containing dtype information (e.g., npt.NDArray[np.float32])
174+
175+
Returns:
176+
Array with the correct dtype applied
177+
178+
Example:
179+
>>> import numpy.typing as npt
180+
>>> arr = np.array([1.0, 2.0], dtype=np.float64)
181+
>>> NDArrayFloat32 = npt.NDArray[np.float32]
182+
>>> result = _apply_numpy_dtype_from_type_annotation(arr, NDArrayFloat32)
183+
>>> result.dtype == np.float32
184+
True
185+
"""
186+
type_args = get_args(target_type)
187+
# type_args for np.ndarray are typically (shape, dtype)
188+
if len(type_args) >= 2:
189+
# Extract the dtype from numpy.dtype[T]
190+
dtype_generic = type_args[1]
191+
# Check if this is a numpy.dtype generic type
192+
if get_origin(dtype_generic) is np.dtype:
193+
dtype_args = get_args(dtype_generic)
194+
if dtype_args:
195+
try:
196+
target_dtype = dtype_args[0]
197+
# Only convert if the dtype is different
198+
if array.dtype != np.dtype(target_dtype):
199+
return array.astype(target_dtype)
200+
except (AttributeError, TypeError, ValueError):
201+
# If we can't extract or apply dtype, just return the array as-is
202+
pass
203+
return array
204+
205+
168206
def from_polars_data(polars_data: Any, target_type: type) -> Any:
169207
"""Convert polars data to target type.
170208
@@ -189,6 +227,18 @@ def from_polars_data(polars_data: Any, target_type: type) -> Any:
189227
if target_type in _from_polars_converters:
190228
return _from_polars_converters[target_type](polars_data)
191229

230+
# Check if target_type is a generic type (e.g., np.ndarray[Any, np.dtype[np.float32]])
231+
origin_type = get_origin(target_type)
232+
if origin_type is not None and origin_type in _from_polars_converters:
233+
# Handle typed numpy arrays and other generic types
234+
result = _from_polars_converters[origin_type](polars_data)
235+
236+
# For typed numpy arrays, apply the dtype if specified in the type annotation
237+
if origin_type is np.ndarray and result is not None:
238+
result = _apply_numpy_dtype_from_type_annotation(result, target_type)
239+
240+
return result
241+
192242
# Handle Union types (e.g., torch.Tensor | np.ndarray)
193243
# Check if target_type is a Union type (Python 3.10+ style or typing.Union)
194244
is_union = False
@@ -198,38 +248,33 @@ def from_polars_data(polars_data: Any, target_type: type) -> Any:
198248
if isinstance(target_type, types.UnionType):
199249
is_union = True
200250
union_args = target_type.__args__
201-
else:
202-
# Check for typing.Union (older syntax: Union[A, B])
203-
try:
204-
from typing import get_args, get_origin
205251

206-
if get_origin(target_type) is Union:
207-
is_union = True
208-
union_args = get_args(target_type)
209-
except Exception as e:
210-
logger.error(f"Error handling Union type: {e}")
252+
# Check for typing.Union (older syntax: Union[A, B])
253+
if get_origin(target_type) is Union:
254+
is_union = True
255+
union_args = get_args(target_type)
211256

212257
if is_union and union_args:
213258
return _convert_union_types(union_args=union_args, polars_data=polars_data, target_type=target_type)
214259
raise TypeError(f"No converter registered for type {target_type}")
215260

216261

217262
def _convert_union_types(union_args: tuple[type], polars_data: Any, target_type: type) -> Any:
218-
if types.NoneType in union_args:
219-
# Handle optional types in union (e.g. A | None) when Polars data is None
220-
if polars_data is None:
221-
return None
222-
223-
union_args = tuple(arg for arg in union_args if arg is not types.NoneType)
224-
225-
# For non-optional Union types, try each type in the union until one succeeds
226-
for union_type in union_args:
227-
if union_type in _from_polars_converters:
228-
try:
229-
return _from_polars_converters[union_type](polars_data)
230-
except KeyError:
231-
# If conversion fails, try the next type in the union
232-
continue
263+
if types.NoneType in union_args and polars_data is None:
264+
return None
265+
266+
non_none_args = tuple(arg for arg in union_args if arg is not types.NoneType)
267+
268+
# Try each type in the union until one succeeds
269+
for union_type in non_none_args:
270+
# Try to convert using the union type (which might be generic)
271+
try:
272+
return from_polars_data(polars_data, union_type)
273+
except (KeyError, TypeError):
274+
# If conversion fails, try the next type in the union
275+
continue
276+
277+
# If all conversions failed, raise TypeError
233278
raise TypeError(f"No converter registered for type {target_type}")
234279

235280

tests/unit/experimental/test_type_registry.py

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -325,3 +325,187 @@ def test_points_converter_functionality():
325325
result = to_numpy(points_obj)
326326
assert isinstance(result, np.ndarray)
327327
np.testing.assert_array_equal(result, np.array([[10.0, 20.0, 2.0], [30.0, 40.0, 1.0]]))
328+
329+
330+
def test_typed_numpy_array_basic():
331+
"""Test basic typed numpy array conversion from Polars data."""
332+
import numpy.typing as npt
333+
import polars as pl
334+
335+
# Test Float32 typed array
336+
NDArrayFloat32 = npt.NDArray[np.float32]
337+
df = pl.DataFrame({"data": [[0.8, 0.9]]}, schema={"data": pl.List(pl.Float32)})
338+
result = from_polars_data(df["data"][0], NDArrayFloat32)
339+
340+
assert isinstance(result, np.ndarray)
341+
assert result.dtype == np.float32
342+
np.testing.assert_array_almost_equal(result, np.array([0.8, 0.9], dtype=np.float32))
343+
344+
# Test Int32 typed array
345+
NDArrayInt32 = npt.NDArray[np.int32]
346+
df = pl.DataFrame({"data": [[1, 2, 3]]}, schema={"data": pl.List(pl.Int32)})
347+
result = from_polars_data(df["data"][0], NDArrayInt32)
348+
349+
assert isinstance(result, np.ndarray)
350+
assert result.dtype == np.int32
351+
np.testing.assert_array_equal(result, np.array([1, 2, 3], dtype=np.int32))
352+
353+
# Test Float64 typed array
354+
NDArrayFloat64 = npt.NDArray[np.float64]
355+
df = pl.DataFrame({"data": [[1.5, 2.5]]}, schema={"data": pl.List(pl.Float64)})
356+
result = from_polars_data(df["data"][0], NDArrayFloat64)
357+
358+
assert isinstance(result, np.ndarray)
359+
assert result.dtype == np.float64
360+
np.testing.assert_array_almost_equal(result, np.array([1.5, 2.5], dtype=np.float64))
361+
362+
363+
def test_typed_numpy_array_dtype_conversion():
364+
"""Test that typed numpy arrays trigger dtype conversion when needed."""
365+
import numpy.typing as npt
366+
import polars as pl
367+
368+
# Test conversion from float64 to float32
369+
NDArrayFloat32 = npt.NDArray[np.float32]
370+
df = pl.DataFrame({"data": [[1.0, 2.0]]}, schema={"data": pl.List(pl.Float64)})
371+
result = from_polars_data(df["data"][0], NDArrayFloat32)
372+
373+
assert result.dtype == np.float32, f"Expected float32 but got {result.dtype}"
374+
np.testing.assert_array_almost_equal(result, np.array([1.0, 2.0], dtype=np.float32))
375+
376+
# Test conversion from int64 to int32
377+
NDArrayInt32 = npt.NDArray[np.int32]
378+
df = pl.DataFrame({"data": [[10, 20]]}, schema={"data": pl.List(pl.Int64)})
379+
result = from_polars_data(df["data"][0], NDArrayInt32)
380+
381+
assert result.dtype == np.int32, f"Expected int32 but got {result.dtype}"
382+
np.testing.assert_array_equal(result, np.array([10, 20], dtype=np.int32))
383+
384+
385+
def test_typed_numpy_array_optional():
386+
"""Test optional typed numpy arrays (Type | None)."""
387+
import numpy.typing as npt
388+
import polars as pl
389+
390+
NDArrayFloat32 = npt.NDArray[np.float32]
391+
OptionalFloat32 = NDArrayFloat32 | None if sys.version_info >= (3, 10) else Optional[NDArrayFloat32]
392+
393+
# Test with None
394+
result = from_polars_data(None, OptionalFloat32)
395+
assert result is None
396+
397+
# Test with actual data
398+
df = pl.DataFrame({"data": [[0.8, 0.9]]}, schema={"data": pl.List(pl.Float32)})
399+
result = from_polars_data(df["data"][0], OptionalFloat32)
400+
401+
assert isinstance(result, np.ndarray)
402+
assert result.dtype == np.float32
403+
np.testing.assert_array_almost_equal(result, np.array([0.8, 0.9], dtype=np.float32))
404+
405+
406+
def test_typed_numpy_array_preserves_dtype():
407+
"""Test that typed numpy arrays preserve dtype from Polars when types match."""
408+
import numpy.typing as npt
409+
import polars as pl
410+
411+
# When Polars dtype matches the type annotation, no conversion should occur
412+
NDArrayFloat32 = npt.NDArray[np.float32]
413+
df = pl.DataFrame({"data": [[0.5, 0.7]]}, schema={"data": pl.List(pl.Float32)})
414+
result = from_polars_data(df["data"][0], NDArrayFloat32)
415+
416+
assert result.dtype == np.float32
417+
# Values should be exact (no float precision loss)
418+
np.testing.assert_array_equal(result, np.array([0.5, 0.7], dtype=np.float32))
419+
420+
421+
def test_typed_numpy_array_various_dtypes():
422+
"""Test typed numpy arrays with various numpy dtypes."""
423+
import numpy.typing as npt
424+
import polars as pl
425+
426+
# Test uint8
427+
NDArrayUInt8 = npt.NDArray[np.uint8]
428+
df = pl.DataFrame({"data": [[1, 2, 3]]}, schema={"data": pl.List(pl.UInt8)})
429+
result = from_polars_data(df["data"][0], NDArrayUInt8)
430+
assert result.dtype == np.uint8
431+
432+
# Test int64
433+
NDArrayInt64 = npt.NDArray[np.int64]
434+
df = pl.DataFrame({"data": [[100, 200]]}, schema={"data": pl.List(pl.Int64)})
435+
result = from_polars_data(df["data"][0], NDArrayInt64)
436+
assert result.dtype == np.int64
437+
438+
# Test uint16
439+
NDArrayUInt16 = npt.NDArray[np.uint16]
440+
df = pl.DataFrame({"data": [[1000, 2000]]}, schema={"data": pl.List(pl.UInt16)})
441+
result = from_polars_data(df["data"][0], NDArrayUInt16)
442+
assert result.dtype == np.uint16
443+
444+
445+
def test_typed_numpy_array_helper_function():
446+
"""Test the _apply_numpy_dtype_from_type_annotation helper function directly."""
447+
import numpy.typing as npt
448+
449+
from datumaro.experimental.type_registry import _apply_numpy_dtype_from_type_annotation
450+
451+
# Test dtype conversion
452+
NDArrayFloat32 = npt.NDArray[np.float32]
453+
arr = np.array([1.0, 2.0], dtype=np.float64)
454+
result = _apply_numpy_dtype_from_type_annotation(arr, NDArrayFloat32)
455+
assert result.dtype == np.float32
456+
457+
# Test no conversion when dtype already matches
458+
arr_f32 = np.array([1.0, 2.0], dtype=np.float32)
459+
result = _apply_numpy_dtype_from_type_annotation(arr_f32, NDArrayFloat32)
460+
assert result.dtype == np.float32
461+
462+
# Test with generic np.ndarray (should not convert)
463+
arr_f64 = np.array([1.0, 2.0], dtype=np.float64)
464+
result = _apply_numpy_dtype_from_type_annotation(arr_f64, np.ndarray)
465+
assert result.dtype == np.float64 # Should remain unchanged
466+
467+
468+
def test_typed_numpy_array_round_trip():
469+
"""Test round-trip conversion: numpy -> polars -> typed numpy."""
470+
import numpy.typing as npt
471+
import polars as pl
472+
473+
NDArrayFloat32 = npt.NDArray[np.float32]
474+
475+
# Original typed array
476+
original = np.array([0.8, 0.95, 0.87], dtype=np.float32)
477+
478+
# Convert to polars-compatible format
479+
from datumaro.experimental.type_registry import to_numpy
480+
481+
polars_ready = to_numpy(original, pl.Float32)
482+
483+
# Create polars series
484+
series = pl.Series("scores", [polars_ready], dtype=pl.List(pl.Float32))
485+
486+
# Extract back from polars
487+
polars_data = series[0]
488+
489+
# Convert back to typed numpy array
490+
result = from_polars_data(polars_data, NDArrayFloat32)
491+
492+
# Verify dtype and values are preserved
493+
assert result.dtype == np.float32
494+
np.testing.assert_array_almost_equal(original, result)
495+
496+
497+
def test_typed_numpy_array_multidimensional():
498+
"""Test typed numpy arrays with multidimensional data."""
499+
import numpy.typing as npt
500+
import polars as pl
501+
502+
NDArrayInt32 = npt.NDArray[np.int32]
503+
504+
# Test with nested lists (2D array)
505+
# Note: Polars List type is for 1D arrays, so we test with flattened data
506+
df = pl.DataFrame({"data": [[10, 15, 30, 35]]}, schema={"data": pl.List(pl.Int32)})
507+
result = from_polars_data(df["data"][0], NDArrayInt32)
508+
509+
assert result.dtype == np.int32
510+
assert result.shape == (4,)
511+
np.testing.assert_array_equal(result, np.array([10, 15, 30, 35], dtype=np.int32))

0 commit comments

Comments
 (0)