Skip to content

Commit 1bbdec3

Browse files
committed
Fix sample validation for complex types
1 parent 3f3287b commit 1bbdec3

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

src/datumaro/experimental/dataset.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from functools import cache
1111
from typing import TYPE_CHECKING, Annotated, Any, Generic, TypeGuard, Union, cast, get_args, get_origin, get_type_hints
1212

13+
import numpy as np
1314
import polars as pl
1415
from typing_extensions import TypeVar, dataclass_transform
1516

@@ -86,7 +87,7 @@ def _validate_attribute_type(self, expected_type: Any, value: Any) -> bool:
8687
collections.abc.Callable,
8788
}:
8889
return callable(value)
89-
return isinstance(value, expected_type)
90+
return isinstance(value, origin or expected_type)
9091

9192
@classmethod
9293
@cache

tests/unit/experimental/test_sample.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Any
66

77
import numpy as np
8+
import numpy.typing as npt
89
import polars as pl
910
import pytest
1011

@@ -17,6 +18,7 @@
1718
bbox_field,
1819
image_field,
1920
image_info_field,
21+
score_field,
2022
)
2123
from datumaro.experimental.fields.images import image_path_field
2224
from datumaro.experimental.schema import Schema, Semantic
@@ -185,3 +187,11 @@ class ExtendedSample(BaseSample):
185187
assert len(extended_schema.attributes) == 3
186188
assert "image_info" in extended_schema.attributes
187189
assert "image_info" not in base_schema.attributes
190+
191+
192+
def test_sample_with_is_list():
193+
class MySample(Sample):
194+
confidence: npt.NDArray[np.float32] | None = score_field(dtype=pl.Float32(), is_list=True)
195+
196+
# Assert that sample can be created without validation errors
197+
MySample(confidence=np.array([0.8]))

0 commit comments

Comments
 (0)