diff --git a/XPointMLTest.py b/XPointMLTest.py index 3027485..696462c 100644 --- a/XPointMLTest.py +++ b/XPointMLTest.py @@ -404,18 +404,19 @@ def _apply_augmentation(self, all_data, mask): noise = torch.randn_like(all_data) * noise_std all_data = all_data + noise - # 5. Random brightness/contrast adjustment per channel (30% chance) - # Helps model become invariant to intensity variations + # 5. Random brightness/contrast adjustment (30% chance) + # CHANGED: Applied globally across channels to preserve physical relationships + # (e.g., keeping the derivative relationship between psi and B fields) if self.rng.random() < 0.3: - for c in range(all_data.shape[0]): - brightness = self.rng.uniform(-0.1, 0.1) - contrast = self.rng.uniform(0.9, 1.1) - mean = all_data[c].mean() - all_data[c] = contrast * (all_data[c] - mean) + mean + brightness + brightness = self.rng.uniform(-0.1, 0.1) + contrast = self.rng.uniform(0.9, 1.1) + # Apply same transformation to all channels + mean = all_data.mean(dim=(-2, -1), keepdim=True) + all_data = contrast * (all_data - mean) + mean + brightness # 6. Cutout/Random erasing (20% chance) # Prevents model from relying too heavily on specific spatial features - if self.rng.random() < 0.2: + if self.rng.random() < 0.20: h, w = all_data.shape[-2:] cutout_size = int(min(h, w) * self.rng.uniform(0.1, 0.25)) if cutout_size > 0: