Skip to content

Commit ff3abb9

Browse files
authored
fix: check for transforms before checking for matching coord sys name, also docstrings (#1559)
1 parent 42381b6 commit ff3abb9

File tree

2 files changed

+147
-33
lines changed

2 files changed

+147
-33
lines changed

src/aind_data_schema/utils/validators.py

Lines changed: 87 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,18 @@ def _time_validation_recurse_helper(data, acquisition_start_time, acquisition_en
106106

107107

108108
def _recurse_helper(data, **kwargs):
109-
"""Helper function for recursive_axis_order_check: recurse calls for lists and objects only"""
109+
"""Helper function for recursive coordinate system validation.
110+
111+
Recursively processes lists and objects, calling recursive_coord_system_check
112+
on each element or attribute.
113+
114+
Parameters
115+
----------
116+
data : Any
117+
The data structure to process recursively
118+
**kwargs
119+
Keyword arguments passed to recursive_coord_system_check
120+
"""
110121
if isinstance(data, list):
111122
for item in data:
112123
recursive_coord_system_check(item, **kwargs)
@@ -122,42 +133,92 @@ def _recurse_helper(data, **kwargs):
122133

123134

124135
def _system_check_helper(data, coordinate_system_name: Optional[str], axis_count: Optional[int]):
125-
"""Helper function to raise errors if the coordinate_system_name or axis_count don't match"""
126-
object_type = getattr(data, "object_type", type(data).__name__)
136+
"""Helper function to validate coordinate system requirements for objects with transforms.
127137
128-
if not coordinate_system_name or not axis_count:
129-
raise ValueError(
130-
f"CoordinateSystem is required when a Transform or Coordinate is present (object_type: {object_type})"
131-
)
138+
Only validates coordinate system requirements if the object contains transform components
139+
(Translation, Rotation, or Scale). Objects without transforms don't require coordinate systems.
132140
133-
if data.coordinate_system_name not in coordinate_system_name:
134-
raise ValueError(
135-
f"System name mismatch for {object_type}, expected {coordinate_system_name}, "
136-
f"found {data.coordinate_system_name}"
137-
)
141+
Parameters
142+
----------
143+
data : object
144+
The object to validate
145+
coordinate_system_name : Optional[str]
146+
Expected coordinate system name (can be None if no transforms present)
147+
axis_count : Optional[int]
148+
Expected number of axes in the coordinate system (can be None if no transforms present)
149+
150+
Raises
151+
------
152+
ValueError
153+
If coordinate system is required but missing, system name doesn't match,
154+
or transform field length doesn't match axis count
155+
"""
156+
# First check if this object has any transform components
157+
has_transforms = False
158+
transform_components = []
159+
object_type = getattr(data, "object_type", type(data).__name__)
138160

139-
# Check lengths of subfields based on class types
140161
if hasattr(data, "__dict__"):
141162
for attr_name, attr_value in data.__dict__.items():
142163
# Check if the attribute's class name is one of the AXIS_TYPES
143164
if hasattr(attr_value, "__class__") and attr_value.__class__.__name__ in AXIS_TYPES:
144-
# Construct the field name by converting the class name to lowercase
145-
field_name = attr_value.__class__.__name__.lower()
146-
sub_data = getattr(data, field_name, None)
147-
# Check if the object has the corresponding field and if it's a list with correct length
148-
if sub_data and hasattr(sub_data, field_name):
149-
field_value = getattr(sub_data, field_name)
150-
if len(field_value) != axis_count:
151-
raise ValueError(
152-
f"Axis count mismatch for {object_type}, expected {axis_count} axes, "
153-
f"but found {len(field_value)}"
154-
)
165+
has_transforms = True
166+
transform_components.append(attr_value)
167+
168+
# Only require coordinate system if there are transforms present
169+
if has_transforms:
170+
if not coordinate_system_name or not axis_count:
171+
raise ValueError(
172+
f"CoordinateSystem is required when a Transform or Coordinate is present (object_type: {object_type})"
173+
)
174+
175+
if data.coordinate_system_name not in coordinate_system_name:
176+
raise ValueError(
177+
f"System name mismatch for {object_type}, expected {coordinate_system_name}, "
178+
f"found {data.coordinate_system_name}"
179+
)
180+
181+
# Check lengths of transform fields match axis count
182+
for transform_component in transform_components:
183+
field_name = transform_component.__class__.__name__.lower()
184+
sub_data = getattr(data, field_name, None)
185+
# Check if the object has the corresponding field and if it's a list with correct length
186+
if sub_data and hasattr(sub_data, field_name):
187+
field_value = getattr(sub_data, field_name)
188+
if len(field_value) != axis_count:
189+
raise ValueError(
190+
f"Axis count mismatch for {object_type}, expected {axis_count} axes, "
191+
f"but found {len(field_value)}"
192+
)
155193

156194

157195
def recursive_coord_system_check(data, coordinate_system_name: Optional[str], axis_count: Optional[int]):
158-
"""Recursively check fields, see if they are Coordinates and check if they match a List[values]
196+
"""Recursively validate coordinate system requirements for objects with transforms.
159197
160-
Note that we just need to check if the axes all show up, not necessarily in matching order
198+
Traverses the data structure and validates that objects with transform components
199+
(Translation, Rotation, Scale) have the correct coordinate system name and that
200+
transform field lengths match the expected axis count.
201+
202+
Parameters
203+
----------
204+
data : Any
205+
The data structure to validate recursively
206+
coordinate_system_name : Optional[str]
207+
Expected coordinate system name (can be None if no transforms present)
208+
axis_count : Optional[int]
209+
Expected number of axes in the coordinate system (can be None if no transforms present)
210+
211+
Raises
212+
------
213+
ValueError
214+
If coordinate system is required but missing, system name doesn't match,
215+
or transform field length doesn't match axis count
216+
217+
Notes
218+
-----
219+
Objects without transform components are not required to have coordinate systems.
220+
When a new coordinate_system is encountered in the data, it overrides the provided
221+
coordinate_system_name and axis_count for subsequent validation.
161222
"""
162223

163224
if not data:

tests/test_utils_validators.py

Lines changed: 60 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -215,20 +215,73 @@ def test_recursive_coord_system_check_with_axis_count_mismatch(self):
215215
self.assertIn("Axis count mismatch", str(context.exception))
216216

217217
def test_recursive_coord_system_check_with_missing_coordinate_system(self):
218-
"""Test recursive_coord_system_check with missing coordinate system"""
218+
"""Test recursive_coord_system_check with missing coordinate system for object WITH transforms"""
219219

220-
class MockData(BaseModel):
221-
"""Test class"""
222-
223-
coordinate_system_name: str
224-
225-
data = MockData(coordinate_system_name=self.coordinate_system_name)
220+
# Object with transforms should still require coordinate system
221+
data = TranslationWrapper(
222+
coordinate_system_name=self.coordinate_system_name, translation=Translation(translation=[0.5, 1])
223+
)
226224

227225
with self.assertRaises(ValueError) as context:
228226
recursive_coord_system_check(data, None, axis_count=0)
229227

230228
self.assertIn("CoordinateSystem is required", str(context.exception))
231229

230+
def test_recursive_coord_system_check_object_without_transforms(self):
231+
"""Test recursive_coord_system_check with object without transforms (should not require coordinate system)"""
232+
233+
class ObjectWithoutTransforms(DataModel):
234+
"""Object without any transform components"""
235+
236+
coordinate_system_name: str
237+
some_field: str
238+
239+
data = ObjectWithoutTransforms(coordinate_system_name=self.coordinate_system_name, some_field="test_value")
240+
241+
# Should not raise any exception even with None coordinate_system_name and axis_count
242+
recursive_coord_system_check(data, None, axis_count=0)
243+
244+
def test_system_check_helper_object_without_transforms(self):
245+
"""Test _system_check_helper with object without transforms (should not require coordinate system)"""
246+
247+
class ObjectWithoutTransforms(DataModel):
248+
"""Object without any transform components"""
249+
250+
coordinate_system_name: str
251+
some_field: str
252+
253+
data = ObjectWithoutTransforms(coordinate_system_name=self.coordinate_system_name, some_field="test_value")
254+
255+
# Should not raise any exception even with None coordinate_system_name and axis_count
256+
_system_check_helper(data, None, axis_count=0)
257+
258+
def test_mixed_objects_with_and_without_transforms(self):
259+
"""Test with a mix of objects with and without transforms"""
260+
261+
class ObjectWithoutTransforms(DataModel):
262+
"""Object without any transform components"""
263+
264+
coordinate_system_name: str
265+
some_field: str
266+
267+
class ContainerModel(DataModel):
268+
"""Container with mixed objects"""
269+
270+
with_transform: TranslationWrapper
271+
without_transform: ObjectWithoutTransforms
272+
273+
container = ContainerModel(
274+
with_transform=TranslationWrapper(
275+
coordinate_system_name=self.coordinate_system_name, translation=Translation(translation=[0.5, 1])
276+
),
277+
without_transform=ObjectWithoutTransforms(
278+
coordinate_system_name="any_name", some_field="test" # This can be anything since no transforms
279+
),
280+
)
281+
282+
# Should pass validation - only the object with transforms is checked
283+
recursive_coord_system_check(container, self.coordinate_system_name, axis_count=2)
284+
232285

233286
class MockEnum(Enum):
234287
"""Mock Enum for testing"""

0 commit comments

Comments
 (0)