
Commit 3f3287b

Sample and dtype validation (#1966)
This pull request introduces improvements to type safety, validation, and schema handling for experimental dataset fields, with updates across core modules and tests. The main changes are stricter type validation for sample attributes, standardized handling of `dtype` for fields, and test updates that reflect these improvements.

**Validation and Type Safety Enhancements**

* Added a `validate()` method to the `Sample` class that checks attribute presence and type correctness against the inferred schema, including support for `Union` and `Callable` types. Validation is now called after initialization.
* Improved the attribute type validation logic to correctly handle complex types such as `Union` and `Callable`.

**Field `dtype` Standardization**

* Updated the `Field` base class to require that `dtype` is always a Polars `DataType` instance, raising a `TypeError` when a `DataType` class (rather than an instance) or any other invalid value is supplied (see the short Polars sketch after this description).
* All field subclasses now explicitly declare their `dtype` using `field(default_factory=...)` or `field(default=..., init=False)`, ensuring consistent and correct schema generation.

**Schema and Serialization Improvements**

* The `from_dict` method for fields now skips non-init dataclass fields during deserialization, preventing errors from fields that shouldn't be set via the constructor.

**Test Suite Updates**

* Updated integration and unit tests to use explicit types for callable/image/mask fields, replacing `Any` with more precise unions (e.g., `np.ndarray | Callable[[], np.ndarray]`). This strengthens type checking and reflects the stricter validation logic.
* Added a new unit test to verify correct validation and conversion of field `dtype`, including error handling for invalid types.

**General Codebase Maintenance**

* Added missing imports and minor refactoring for clarity and correctness in test and core files.

These changes collectively improve the robustness and maintainability of the experimental dataset system, especially around schema definition, attribute validation, and field type safety.

Resolves #1855

### Checklist

- [x] I have added tests to cover my changes or documented any manual tests.
- [ ] I have updated the [documentation](https://github.com/open-edge-platform/datumaro/tree/develop/docs) accordingly

---------

Signed-off-by: Jort Bergfeld <[email protected]>
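As an illustration of the `dtype` rule described above (plain Polars behavior, not Datumaro API): `pl.Float32` is a dtype class, while `pl.Float32()` is a `DataType` instance, and only the latter satisfies the new check.

```python
import polars as pl

# pl.Float32 is the DataType class; calling it produces a DataType instance.
print(isinstance(pl.Float32, pl.DataType))    # False -- a class, not an instance
print(isinstance(pl.Float32(), pl.DataType))  # True  -- what Field.dtype now requires
print(pl.Float32() == pl.Float32())           # True  -- equal instances compare equal
```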
1 parent 366941e commit 3f3287b

27 files changed: +483 -323 lines changed

src/datumaro/experimental/converters/image_converters.py

Lines changed: 4 additions & 4 deletions
@@ -106,7 +106,7 @@ def filter_output_spec(self) -> bool:
     name=self.output_image.name,
     field=ImageField(
         semantic=self.input_image.field.semantic,
-        dtype=pl.Float32,
+        dtype=pl.Float32(),
         format=self.input_image.field.format,
         channels_first=self.output_image.field.channels_first,
     ),
@@ -161,7 +161,7 @@ def filter_output_spec(self) -> bool:
     name=self.output_image.name,
     field=ImageField(
         semantic=self.input_path.field.semantic,
-        dtype=pl.UInt8,  # Default to UInt8 for loaded images
+        dtype=pl.UInt8(),  # Default to UInt8 for loaded images
         format="RGB",  # Default to RGB format
         channels_first=self.output_image.field.channels_first,
     ),
@@ -275,7 +275,7 @@ def filter_output_spec(self) -> bool:
     name=self.output_image.name,
     field=ImageField(
         semantic=self.input_bytes.field.semantic,
-        dtype=pl.UInt8,  # Default to UInt8 for decoded images
+        dtype=pl.UInt8(),  # Default to UInt8 for decoded images
         format="RGB",  # Default to RGB format
         channels_first=self.output_image.field.channels_first,
     ),
@@ -342,7 +342,7 @@ def filter_output_spec(self) -> bool:
     name=self.output_image.name,
     field=ImageField(
         semantic=self.input_callable.field.semantic,
-        dtype=pl.UInt8,  # Default to UInt8 for image data
+        dtype=pl.UInt8(),  # Default to UInt8 for image data
         format=self.input_callable.field.format,  # Use format from callable field
         channels_first=self.output_image.field.channels_first,
     ),

src/datumaro/experimental/converters/mask_converters.py

Lines changed: 4 additions & 4 deletions
@@ -153,7 +153,7 @@ def apply_conversion_batch(batch_df: pl.DataFrame) -> pl.DataFrame:

     return pl.struct(
         pl.Series(results_batch_polygons).alias("mask"),
-        pl.Series(results_batch_shape, dtype=pl.List(pl.Int32)).alias("shape"),
+        pl.Series(results_batch_shape, dtype=pl.List(pl.Int32())).alias("shape"),
         eager=True,
     )

@@ -165,7 +165,7 @@ def apply_conversion_batch(batch_df: pl.DataFrame) -> pl.DataFrame:
         ]
     ).map_batches(
         apply_conversion_batch,
-        return_dtype=pl.Struct({"mask": pl.List(pl.UInt8), "shape": pl.List(pl.Int32)}),
+        return_dtype=pl.Struct({"mask": pl.List(pl.UInt8()), "shape": pl.List(pl.Int32())}),
     )

     return df.with_columns(
@@ -282,7 +282,7 @@ def apply_conversion_batch(batch_df: pl.DataFrame, **kwargs) -> pl.DataFrame: #

     return pl.struct(
         pl.Series(results_batch_mask).alias("mask"),
-        pl.Series(results_batch_shape, dtype=pl.List(pl.Int32)).alias("shape"),
+        pl.Series(results_batch_shape, dtype=pl.List(pl.Int32())).alias("shape"),
         eager=True,
     )

@@ -294,7 +294,7 @@ def apply_conversion_batch(batch_df: pl.DataFrame, **kwargs) -> pl.DataFrame: #
     ).map_batches(
         apply_conversion_batch,
         return_dtype=pl.Struct(
-            {"mask": pl.List(self.output_instance_mask.field.dtype), "shape": pl.List(pl.Int32)}
+            {"mask": pl.List(self.output_instance_mask.field.dtype), "shape": pl.List(pl.Int32())}
         ),
     )
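
For context, a minimal Polars-only sketch of the nested dtypes used above; the names are illustrative.

```python
import polars as pl

# Nested dtypes are composed from DataType instances, mirroring the pattern above.
shape_dtype = pl.List(pl.Int32())
result_dtype = pl.Struct({"mask": pl.List(pl.UInt8()), "shape": shape_dtype})

shapes = pl.Series("shape", [[128, 128], [64, 64]], dtype=shape_dtype)
print(shapes.dtype)   # List(Int32)
print(result_dtype)   # Struct with 'mask' and 'shape' entries
```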

src/datumaro/experimental/dataset.py

Lines changed: 43 additions & 1 deletion
@@ -4,7 +4,9 @@

 from __future__ import annotations

+import collections.abc
 import types
+import typing
 from functools import cache
 from typing import TYPE_CHECKING, Annotated, Any, Generic, TypeGuard, Union, cast, get_args, get_origin, get_type_hints

@@ -35,8 +37,8 @@ def __init__(self, **kwargs: Any):
         """Initialize sample with provided attributes."""
         for key, value in kwargs.items():
             setattr(self, key, value)
-
         self.__post_init__()
+        self.validate()

     def __post_init__(self) -> None:
         pass
@@ -46,6 +48,46 @@ def __repr__(self):
         fields = ", ".join(f"{key}={getattr(self, key)}" for key in self.__dict__ if not key.startswith("_"))
         return f"{self.__class__.__name__}({fields})"

+    def validate(self) -> None:
+        """
+        Validate the sample's attributes against the inferred schema.
+
+        Raises:
+            ValueError: If required attributes are missing
+            TypeError: If attribute types do not match the schema
+        """
+        schema = self.__class__.infer_schema()  # Cached per class
+        for name, attr_info in schema.attributes.items():
+            if name not in self.__dict__:
+                continue
+            value = getattr(self, name)
+            expected_type = attr_info.type
+            field = attr_info.field
+
+            if not self._validate_attribute_type(expected_type, value):
+                raise TypeError(f"Attribute `{name}` must be of type `{expected_type}`.")
+
+            # Custom field validation (if any)
+            if hasattr(field, "validate"):
+                field.validate(value)
+
+    def _validate_attribute_type(self, expected_type: Any, value: Any) -> bool:
+        """
+        Recursively validate attribute type, handling Union and Callable types.
+        """
+        # Union and Callable types have to be handled separately,
+        # because isinstance() does not work with Callable types.
+        origin = get_origin(expected_type)
+        if origin is Union:
+            # Check each type in the Union
+            return any(self._validate_attribute_type(typ, value) for typ in get_args(expected_type))
+        if origin in {typing.Callable, collections.abc.Callable} or expected_type in {
+            typing.Callable,
+            collections.abc.Callable,
+        }:
+            return callable(value)
+        return isinstance(value, expected_type)
+
     @classmethod
     @cache
     def infer_schema(cls) -> Schema:
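
For reference, a standalone sketch (not the Datumaro API) of why `Union` and `Callable` need special handling in `_validate_attribute_type`: `isinstance()` rejects both as its second argument, so they are unwrapped first.

```python
import collections.abc
import typing
from typing import Callable, Union, get_args, get_origin

import numpy as np


def check(expected_type, value) -> bool:
    """Standalone re-statement of the Union/Callable-aware check added above."""
    origin = get_origin(expected_type)
    if origin is Union:
        return any(check(t, value) for t in get_args(expected_type))
    if origin in {typing.Callable, collections.abc.Callable} or expected_type in {
        typing.Callable,
        collections.abc.Callable,
    }:
        return callable(value)
    return isinstance(value, expected_type)


image_type = Union[np.ndarray, Callable[[], np.ndarray]]
print(check(image_type, np.zeros((2, 2))))     # True: matches np.ndarray
print(check(image_type, lambda: np.zeros(2)))  # True: matches Callable
print(check(image_type, "not an image"))       # False
```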

src/datumaro/experimental/export_import.py

Lines changed: 3 additions & 3 deletions
@@ -393,10 +393,10 @@ def _update_dataframe_with_field(
     if is_path_field:
         if field_name in df.columns:
             df = df.drop(field_name)
-        return df.with_columns(pl.Series(field_name, values, dtype=pl.String))
+        return df.with_columns(pl.Series(field_name, values, dtype=pl.String()))
     if field_name in df.columns:
         return df.with_columns(pl.Series(field_name, values))
-    return df.with_columns(pl.Series(field_name, values, dtype=pl.Object))
+    return df.with_columns(pl.Series(field_name, values, dtype=pl.Object()))


 def _reconstruct_image_fields(
@@ -430,7 +430,7 @@ def _add_missing_object_columns(
     """Add back any object columns that weren't reconstructed from images."""
     for col_name in object_columns:
         if col_name not in df.columns:
-            df = df.with_columns(pl.Series(col_name, [None] * len(df), dtype=pl.Object))
+            df = df.with_columns(pl.Series(col_name, [None] * len(df), dtype=pl.Object()))
     return df
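
For context, a minimal sketch of constructing Series with explicit dtype instances, as the helpers above now do; the column names are illustrative.

```python
import polars as pl

paths = pl.Series("image_path", ["a.jpg", "b.jpg"], dtype=pl.String())
objects = pl.Series("payload", [{"any": "python object"}, None], dtype=pl.Object())

df = pl.DataFrame([paths, objects])
print(df.schema)  # image_path: String, payload: Object
```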

src/datumaro/experimental/fields/annotations.py

Lines changed: 8 additions & 8 deletions
@@ -6,7 +6,7 @@

 import polars as pl

-from datumaro.experimental.fields.base import Field, PolarsDataType, Semantic, T, convert_numpy_object_array_to_series
+from datumaro.experimental.fields.base import Field, Semantic, T, convert_numpy_object_array_to_series
 from datumaro.experimental.type_registry import from_polars_data, to_numpy

@@ -26,7 +26,7 @@ class BBoxField(Field):
     """

     semantic: Semantic
-    dtype: PolarsDataType = field(default_factory=pl.Float32)
+    dtype: pl.DataType = field(default_factory=pl.Float32)
     format: str = "x1y1x2y2"
     normalize: bool = False

@@ -95,7 +95,7 @@ class RotatedBBoxField(Field):
     """

     semantic: Semantic
-    dtype: PolarsDataType = field(default_factory=pl.Float32)
+    dtype: pl.DataType = field(default_factory=pl.Float32)
     format: str = "cxcywhr"
     normalize: bool = False

@@ -159,7 +159,7 @@ class LabelField(Field):
     """

     semantic: Semantic
-    dtype: PolarsDataType = field(default_factory=pl.UInt8)
+    dtype: pl.DataType = field(default_factory=pl.UInt8)
     multi_label: bool = False  # Flag to indicate if this field should handle multi-labels
     is_list: bool = False

@@ -217,7 +217,7 @@ class ScoreField(Field):
     """

     semantic: Semantic
-    dtype: PolarsDataType = field(default_factory=pl.Float32)
+    dtype: pl.DataType = field(default_factory=pl.Float32)
     is_list: bool = False

     @property
@@ -276,7 +276,7 @@ class PolygonField(Field):
     """

     semantic: Semantic
-    dtype: PolarsDataType = field(default_factory=pl.Float32)
+    dtype: pl.DataType = field(default_factory=pl.Float32)
     format: str = "xy"
     normalize: bool = False

@@ -335,7 +335,7 @@ class KeypointsField(Field):
     """

     semantic: Semantic
-    dtype: PolarsDataType = field(default_factory=pl.Float32)
+    dtype: pl.DataType = field(default_factory=pl.Float32)
     normalize: bool = False

     def to_polars_schema(self, name: str) -> dict[str, pl.DataType]:
@@ -395,7 +395,7 @@ class EllipseField(Field):
     """

     semantic: Semantic
-    dtype: PolarsDataType = field(default_factory=pl.Float32)
+    dtype: pl.DataType = field(default_factory=pl.Float32)
     format: str = "x1y1x2y2"
     normalize: bool = False
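
As a side note, `field(default_factory=pl.Float32)` calls the class each time a field object is created, so the resulting `dtype` is already a `pl.Float32()` instance; a minimal standalone sketch with a hypothetical `Example` dataclass:

```python
from dataclasses import dataclass, field

import polars as pl


@dataclass
class Example:  # hypothetical class, only to show the default_factory pattern
    dtype: pl.DataType = field(default_factory=pl.Float32)


print(Example().dtype)                           # Float32
print(isinstance(Example().dtype, pl.DataType))  # True
```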

src/datumaro/experimental/fields/base.py

Lines changed: 15 additions & 3 deletions
@@ -12,15 +12,13 @@
 from dataclasses import fields as dataclass_fields
 from dataclasses import is_dataclass
 from enum import Flag, auto
-from typing import Any, TypeAlias, TypeVar
+from typing import Any, TypeVar

 import numpy as np
 import polars as pl

 T = TypeVar("T")

-PolarsDataType: TypeAlias = type[pl.DataType] | pl.DataType
-

 class Semantic(Flag):
     """
@@ -48,6 +46,17 @@ class Field:
     """

     semantic: Semantic
+    dtype: pl.DataType
+
+    def __post_init__(self):
+        dtype = getattr(self, "dtype")
+        if isinstance(dtype, type) and issubclass(dtype, pl.DataType):
+            raise TypeError(
+                f"dtype must be a Polars 'DataType' (instance), not a Polars 'DataTypeClass' (type). "
+                f"Make sure your dtype declaration uses parentheses ({dtype.__name__}() instead of {dtype.__name__})"
+            )
+        if not isinstance(dtype, pl.DataType):
+            raise TypeError(f"dtype must be a Polars 'DataType', got '{dtype.__name__}' instead.")

     def to_polars_schema(self, name: str) -> dict[str, pl.DataType]:
         """
@@ -165,6 +174,9 @@ def from_dict(cls, field_dict: dict[str, Any]) -> "Field":
         # Use dataclass introspection to get all expected fields
         if is_dataclass(field_class):
             for dc_field in dataclass_fields(field_class):
+                if not dc_field.init:
+                    continue  # Skip fields that are not in __init__
+
                 field_name = dc_field.name

                 # Skip if not in the serialized data
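
To make the new contract concrete, a standalone sketch with a hypothetical `StrictDtypeField` dataclass that mirrors (rather than imports) the `__post_init__` check above:

```python
from dataclasses import dataclass

import polars as pl


@dataclass
class StrictDtypeField:  # hypothetical stand-in mirroring Field.__post_init__
    dtype: pl.DataType

    def __post_init__(self):
        if isinstance(self.dtype, type) and issubclass(self.dtype, pl.DataType):
            raise TypeError("dtype must be a DataType instance; use parentheses, e.g. pl.Float32()")
        if not isinstance(self.dtype, pl.DataType):
            raise TypeError("dtype must be a Polars DataType")


StrictDtypeField(dtype=pl.Float32())    # accepted: DataType instance
try:
    StrictDtypeField(dtype=pl.Float32)  # rejected: DataType class
except TypeError as err:
    print(err)
```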

src/datumaro/experimental/fields/datasets.py

Lines changed: 15 additions & 2 deletions
@@ -2,7 +2,7 @@
 #
 # SPDX-License-Identifier: MIT
 import types
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from enum import Enum, auto
 from typing import Any, Union, get_args, get_origin

@@ -44,6 +44,18 @@ class TileField(Field):
     """

     semantic: Semantic
+    dtype: pl.DataType = field(
+        default_factory=lambda: pl.Struct(
+            [
+                pl.Field("source_sample_idx", pl.Int32()),
+                pl.Field("x", pl.Int32()),
+                pl.Field("y", pl.Int32()),
+                pl.Field("width", pl.Int32()),
+                pl.Field("height", pl.Int32()),
+            ]
+        ),
+        init=False,
+    )

     def to_polars_schema(self, name: str) -> dict[str, pl.DataType]:
         """Generate Polars schema for tile information."""
@@ -123,6 +135,7 @@ class SubsetField(Field):

     semantic: Semantic
     categories: list[str] | None = None
+    dtype: pl.DataType = field(default_factory=pl.Categorical, init=False)

     def to_polars_schema(self, name: str) -> dict[str, pl.DataType]:
         """Generate schema with categorical type for subset values."""
@@ -141,7 +154,7 @@ def to_polars(self, name: str, value: Any) -> dict[str, pl.Series]:
         polars_value = str(value)

         # Create categorical series with predefined categories if available
-        return {name: pl.Series(name, [polars_value], dtype=pl.Categorical)}
+        return {name: pl.Series(name, [polars_value], dtype=pl.Categorical())}

     def from_polars(self, name: str, row_index: int, df: pl.DataFrame, target_type: type[T]) -> T:
         """Reconstruct subset value from Polars data.

src/datumaro/experimental/fields/images.py

Lines changed: 9 additions & 3 deletions
@@ -7,7 +7,7 @@
 import numpy as np
 import polars as pl

-from datumaro.experimental.fields.base import Field, PolarsDataType, Semantic, T
+from datumaro.experimental.fields.base import Field, Semantic, T
 from datumaro.experimental.type_registry import from_polars_data, to_numpy

@@ -26,7 +26,7 @@ class TensorField(Field):
     """

     semantic: Semantic
-    dtype: PolarsDataType = field(default_factory=pl.UInt8)
+    dtype: pl.DataType = field(default_factory=pl.UInt8)
     channels_first: bool = False

     def to_polars_schema(self, name: str) -> dict[str, pl.DataType]:
@@ -129,6 +129,7 @@ class ImageBytesField(Field):
     """

     semantic: Semantic
+    dtype: pl.DataType = field(default_factory=pl.Binary, init=False)

     def to_polars_schema(self, name: str) -> dict[str, pl.DataType]:
         """Generate schema for image bytes as binary data."""
@@ -174,6 +175,9 @@ class ImageInfoField(Field):
     """

     semantic: Semantic
+    dtype: pl.DataType = field(
+        default_factory=lambda: pl.Struct([pl.Field("width", pl.Int32()), pl.Field("height", pl.Int32())]), init=False
+    )

     def to_polars_schema(self, name: str) -> dict[str, pl.DataType]:
         return {
@@ -228,6 +232,7 @@ class ImagePathField(Field):
     """

     semantic: Semantic
+    dtype: pl.DataType = field(default_factory=pl.String, init=False)

     def to_polars_schema(self, name: str) -> dict[str, pl.DataType]:
         """Generate schema for string path column."""
@@ -272,10 +277,11 @@ class ImageCallableField(Field):

     semantic: Semantic
     format: str = "RGB"
+    dtype: pl.DataType = field(default_factory=pl.Object, init=False)

     def to_polars_schema(self, name: str) -> dict[str, pl.DataType]:
         """Return schema with Object type to store callable."""
-        return {name: pl.Object}
+        return {name: pl.Object()}

     def to_polars(self, name: str, value: callable) -> dict[str, pl.Series]:
         """Store callable as Object in Polars series."""

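For context, a minimal Polars-only sketch of the Object-column pattern `ImageCallableField` relies on; the loader function here is hypothetical.

```python
import numpy as np
import polars as pl


def load_image() -> np.ndarray:  # hypothetical lazy loader
    return np.zeros((4, 4, 3), dtype=np.uint8)


# Object columns can hold arbitrary Python objects, including callables.
loaders = pl.Series("image_loader", [load_image], dtype=pl.Object())
print(callable(loaders[0]), loaders[0]().shape)  # True (4, 4, 3)
```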