File tree Expand file tree Collapse file tree 1 file changed +16
-0
lines changed Expand file tree Collapse file tree 1 file changed +16
-0
lines changed Original file line number Diff line number Diff line change @@ -530,6 +530,22 @@ class ZeroMeanUnitVarianceKwargs(ProcessingKwargs):
530530 eps : Annotated [float , Interval (gt = 0 , le = 0.1 )] = 1e-6
531531 """epsilon for numeric stability: `out = (tensor - mean) / (std + eps)`."""
532532
533+ @model_validator (mode = "after" )
534+ def validate_axes_based_on_mode (self )-> Self :
535+ mode = self .mode
536+ axes = self .axes
537+
538+ if mode == "per_sample" :
539+ # In 'per_sample' mode, ensure batch axis is not included in 'axes'
540+ if axes and 'batch' in axes :
541+ raise ValueError ("Batch axis should not be included in 'axes' for 'per_sample' mode." )
542+ elif mode == "per_dataset" :
543+ # In 'per_dataset' mode, check if batch axis handling is appropriate
544+ # Implement any specific logic you need here
545+ if axes and 'batch' not in axes :
546+ raise ValueError ("Batch axis must be included in 'axes' for 'per_dataset' mode." )
547+ return self
548+
533549
534550class ZeroMeanUnitVariance (ProcessingBase ):
535551 """Subtract mean and divide by variance."""
You can’t perform that action at this time.
0 commit comments