Skip to content

Commit 665fb2b

Browse files
committed
Reduce augmentation probabilities to fix underfitting
1 parent 0e86a49 commit 665fb2b

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

XPointMLTest.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -381,8 +381,8 @@ def _apply_augmentation(self, all_data, mask):
381381
return all_data, mask
382382

383383
# 1. Random rotation (0, 90, 180, 270 degrees)
384-
# 75% chance to apply rotation
385-
if self.rng.random() < 0.75:
384+
# 50% chance to apply rotation
385+
if self.rng.random() < 0.50:
386386
k = self.rng.integers(1, 4) # 1, 2, or 3 (90°, 180°, 270°)
387387
all_data = torch.rot90(all_data, k=k, dims=(-2, -1))
388388
mask = torch.rot90(mask, k=k, dims=(-2, -1))
@@ -397,9 +397,9 @@ def _apply_augmentation(self, all_data, mask):
397397
all_data = torch.flip(all_data, dims=(-2,))
398398
mask = torch.flip(mask, dims=(-2,))
399399

400-
# 4. Add Gaussian noise (30% chance)
400+
# 4. Add Gaussian noise (10% chance)
401401
# Small noise helps prevent overfitting to exact pixel values
402-
if self.rng.random() < 0.3:
402+
if self.rng.random() < 0.1:
403403
noise_std = self.rng.uniform(0.005, 0.02)
404404
noise = torch.randn_like(all_data) * noise_std
405405
all_data = all_data + noise
@@ -413,9 +413,9 @@ def _apply_augmentation(self, all_data, mask):
413413
mean = all_data[c].mean()
414414
all_data[c] = contrast * (all_data[c] - mean) + mean + brightness
415415

416-
# 6. Cutout/Random erasing (20% chance)
416+
# 6. Cutout/Random erasing (5% chance)
417417
# Prevents model from relying too heavily on specific spatial features
418-
if self.rng.random() < 0.2:
418+
if self.rng.random() < 0.05:
419419
h, w = all_data.shape[-2:]
420420
cutout_size = int(min(h, w) * self.rng.uniform(0.1, 0.25))
421421
if cutout_size > 0:

0 commit comments

Comments
 (0)