@@ -404,14 +404,15 @@ def _apply_augmentation(self, all_data, mask):
404404 noise = torch .randn_like (all_data ) * noise_std
405405 all_data = all_data + noise
406406
407- # 5. Random brightness/contrast adjustment per channel (30% chance)
408- # Helps model become invariant to intensity variations
407+ # 5. Random brightness/contrast adjustment (30% chance)
408+ # CHANGED: Applied globally across channels to preserve physical relationships
409+ # (e.g., keeping the derivative relationship between psi and B fields)
409410 if self .rng .random () < 0.3 :
410- for c in range ( all_data . shape [ 0 ]):
411- brightness = self .rng .uniform (- 0.1 , 0 .1 )
412- contrast = self . rng . uniform ( 0.9 , 1.1 )
413- mean = all_data [ c ] .mean ()
414- all_data [ c ] = contrast * (all_data [ c ] - mean ) + mean + brightness
411+ brightness = self . rng . uniform ( - 0.1 , 0.1 )
412+ contrast = self .rng .uniform (0.9 , 1 .1 )
413+ # Apply same transformation to all channels
414+ mean = all_data .mean (dim = ( - 2 , - 1 ), keepdim = True )
415+ all_data = contrast * (all_data - mean ) + mean + brightness
415416
416417 # 6. Cutout/Random erasing (5% chance)
417418 # Prevents model from relying too heavily on specific spatial features
0 commit comments