Skip to content

Commit a946b1f

Browse files
committed
attempt at mode check in one class
1 parent 7b86d31 commit a946b1f

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

bioimageio/spec/model/v0_5.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,22 @@ class ZeroMeanUnitVarianceKwargs(ProcessingKwargs):
530530
eps: Annotated[float, Interval(gt=0, le=0.1)] = 1e-6
531531
"""epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""
532532

533+
@model_validator(mode="after")
534+
def validate_axes_based_on_mode(self)-> Self:
535+
mode = self.mode
536+
axes = self.axes
537+
538+
if mode == "per_sample":
539+
# In 'per_sample' mode, ensure batch axis is not included in 'axes'
540+
if axes and 'batch' in axes:
541+
raise ValueError("Batch axis should not be included in 'axes' for 'per_sample' mode.")
542+
elif mode == "per_dataset":
543+
# In 'per_dataset' mode, check if batch axis handling is appropriate
544+
# Implement any specific logic you need here
545+
if axes and 'batch' not in axes:
546+
raise ValueError("Batch axis must be included in 'axes' for 'per_dataset' mode.")
547+
return self
548+
533549

534550
class ZeroMeanUnitVariance(ProcessingBase):
535551
"""Subtract mean and divide by variance."""

0 commit comments

Comments
 (0)