Skip to content

Commit 9fbdff2

Browse files
committed
Fix augmentation bug: apply brightness/contrast globally to preserve physical field relationships
1 parent 665fb2b commit 9fbdff2

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

XPointMLTest.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -404,14 +404,15 @@ def _apply_augmentation(self, all_data, mask):
404404
noise = torch.randn_like(all_data) * noise_std
405405
all_data = all_data + noise
406406

407-
# 5. Random brightness/contrast adjustment per channel (30% chance)
408-
# Helps model become invariant to intensity variations
407+
# 5. Random brightness/contrast adjustment (30% chance)
408+
# CHANGED: Applied globally across channels to preserve physical relationships
409+
# (e.g., keeping the derivative relationship between psi and B fields)
409410
if self.rng.random() < 0.3:
410-
for c in range(all_data.shape[0]):
411-
brightness = self.rng.uniform(-0.1, 0.1)
412-
contrast = self.rng.uniform(0.9, 1.1)
413-
mean = all_data[c].mean()
414-
all_data[c] = contrast * (all_data[c] - mean) + mean + brightness
411+
brightness = self.rng.uniform(-0.1, 0.1)
412+
contrast = self.rng.uniform(0.9, 1.1)
413+
# Apply same transformation to all channels
414+
mean = all_data.mean(dim=(-2, -1), keepdim=True)
415+
all_data = contrast * (all_data - mean) + mean + brightness
415416

416417
# 6. Cutout/Random erasing (5% chance)
417418
# Prevents model from relying too heavily on specific spatial features

0 commit comments

Comments
 (0)