Skip to content

Commit 782030b

Browse files
Add aug doc and reformat aug code
1 parent 3b88e2d commit 782030b

File tree

23 files changed

+216
-64
lines changed

23 files changed

+216
-64
lines changed

configs/Demo_Hard_Label.yaml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@ use_torchvision: false
1515
# data augmentation
1616
data_aug_func: "dsa"
1717
aug_params:
18-
prob_flip: 0.5
19-
ratio_rotate: 15.0
18+
flip: 0.5
19+
rotate: 15.0
2020
saturation: 2.0
2121
brightness: 1.0
2222
contrast: 0.5
23-
ratio_scale: 1.2
24-
ratio_crop_pad: 0.125
25-
ratio_cutout: 0.5
23+
scale: 1.2
24+
crop: 0.125
25+
cutout: 0.5
2626
use_zca: false
2727

2828
custom_train_trans: null

configs/Demo_Soft_Label.yaml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,14 @@ teacher_dir: ./teacher_models
1616
# syntheticdata augmentation
1717
data_aug_func: dsa
1818
aug_params:
19-
prob_flip: 0.5
20-
ratio_rotate: 15.0
19+
flip: 0.5
20+
rotate: 15.0
2121
saturation: 2.0
2222
brightness: 1.0
2323
contrast: 0.5
24-
ratio_scale: 1.2
25-
ratio_crop_pad: 0.125
26-
ratio_cutout: 0.5
24+
scale: 1.2
25+
crop: 0.125
26+
cutout: 0.5
2727
use_zca: true
2828
custom_train_trans: null
2929

dd_ranking/__init__.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
import dd_ranking.aug as aug
2-
import dd_ranking.config as config
3-
import dd_ranking.loss as loss
4-
import dd_ranking.metrics as metrics
5-
import dd_ranking.utils as utils
1+
from .aug import DSA, Mixup, Cutmix, ZCAWhitening
2+
from .config import Config
3+
from .loss import KLDivergenceLoss, SoftCrossEntropyLoss
4+
from .metrics import HardScoreEvaluator, SoftLabelEvaluator, GeneralEvaluator
5+
from .utils import get_dataset, build_model, get_convnet, get_lenet, get_resnet, get_vgg, get_alexnet

dd_ranking/aug/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .dsa import DSAugmentation
2-
from .mixup import MixupAugmentation
3-
from .cutmix import CutmixAugmentation
4-
from .zca import ZCAWhiteningAugmentation
1+
from .dsa import DSA
2+
from .mixup import Mixup
3+
from .cutmix import Cutmix
4+
from .zca import ZCAWhitening

dd_ranking/aug/cutmix.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import kornia
44

55

6-
class CutmixAugmentation:
6+
class Cutmix:
77
def __init__(self, params: dict):
8-
self.cutmix_p = params["cutmix_p"]
8+
self.beta = params["beta"]
99

1010
def rand_bbox(self, size, lam):
1111
W = size[2]
@@ -27,7 +27,7 @@ def rand_bbox(self, size, lam):
2727

2828
def cutmix(self, images):
2929
rand_index = torch.randperm(images.size()[0]).to(images.device)
30-
lam = np.random.beta(self.cutmix_p, self.cutmix_p)
30+
lam = np.random.beta(self.beta, self.beta)
3131
bbx1, bby1, bbx2, bby2 = self.rand_bbox(images.size(), lam)
3232

3333
images[:, :, bbx1:bbx2, bby1:bby2] = images[rand_index, :, bbx1:bbx2, bby1:bby2]

dd_ranking/aug/dsa.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch.nn.functional as F
44

55

6-
class DSAugmentation:
6+
class DSA:
77

88
def __init__(self, params: dict, seed: int=-1, aug_mode: str='S'):
99
self.params = params
@@ -30,7 +30,7 @@ def set_seed_DiffAug(self):
3030
def rand_scale(self, x):
3131
# x>1, max scale
3232
# sx, sy: (0, +oo), 1: orignial size, 0.5: enlarge 2 times
33-
ratio = self.params["ratio_scale"]
33+
ratio = self.params["scale"]
3434
self.set_seed_DiffAug()
3535
sx = torch.rand(x.shape[0]) * (ratio - 1.0/ratio) + 1.0/ratio
3636
self.set_seed_DiffAug()
@@ -45,7 +45,7 @@ def rand_scale(self, x):
4545
return x
4646

4747
def rand_rotate(self, x): # [-180, 180], 90: anticlockwise 90 degree
48-
ratio = self.params["ratio_rotate"]
48+
ratio = self.params["rotate"]
4949
self.set_seed_DiffAug()
5050
theta = (torch.rand(x.shape[0]) - 0.5) * 2 * ratio / 180 * float(np.pi)
5151
theta = [[[torch.cos(theta[i]), torch.sin(-theta[i]), 0],
@@ -58,7 +58,7 @@ def rand_rotate(self, x): # [-180, 180], 90: anticlockwise 90 degree
5858
return x
5959

6060
def rand_flip(self, x):
61-
prob = self.params["prob_flip"]
61+
prob = self.params["flip"]
6262
self.set_seed_DiffAug()
6363
randf = torch.rand(x.size(0), 1, 1, 1, device=x.device)
6464
if self.params["siamese"]: # Siamese augmentation:
@@ -99,7 +99,7 @@ def rand_color(self, x):
9999

100100
def rand_crop(self, x):
101101
# The image is padded on its surrounding and then cropped.
102-
ratio = self.params["ratio_crop_pad"]
102+
ratio = self.params["crop"]
103103
shift_x, shift_y = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
104104
self.set_seed_DiffAug()
105105
translation_x = torch.randint(-shift_x, shift_x + 1, size=[x.size(0), 1, 1], device=x.device)
@@ -120,7 +120,7 @@ def rand_crop(self, x):
120120
return x
121121

122122
def rand_cutout(self, x):
123-
ratio = self.params["ratio_cutout"]
123+
ratio = self.params["cutout"]
124124
cutout_size = int(x.size(2) * ratio + 0.5), int(x.size(3) * ratio + 0.5)
125125
self.set_seed_DiffAug()
126126
offset_x = torch.randint(0, x.size(2) + (1 - cutout_size[0] % 2), size=[x.size(0), 1, 1], device=x.device)

dd_ranking/aug/mixup.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
import kornia
44

55

6-
class MixupAugmentation:
6+
class Mixup:
77
def __init__(self, params: dict):
8-
self.mixup_p = params["mixup_p"]
8+
self.lambda_ = params["lambda"]
99

1010
def mixup(self, images):
1111
rand_index = torch.randperm(images.size()[0]).to(images.device)
12-
lam = np.random.beta(self.mixup_p, self.mixup_p)
12+
lam = np.random.beta(self.lambda_, self.lambda_)
1313

1414
mixed_images = lam * images + (1 - lam) * images[rand_index]
1515
return mixed_images

dd_ranking/aug/zca.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import kornia
22

33

4-
class ZCAWhiteningAugmentation:
4+
class ZCAWhitening:
55
def __init__(self, params: dict):
66
self.transform = kornia.enhance.ZCAWhitening()
77

dd_ranking/metrics/general.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from dd_ranking.utils import set_seed, get_optimizer, get_lr_scheduler
1515
from dd_ranking.utils import train_one_epoch, validate
1616
from dd_ranking.loss import SoftCrossEntropyLoss, KLDivergenceLoss
17-
from dd_ranking.aug import DSAugmentation, MixupAugmentation, CutmixAugmentation, ZCAWhiteningAugmentation
17+
from dd_ranking.aug import DSA, Mixup, Cutmix, ZCAWhitening
1818
from dd_ranking.config import Config
1919

2020

@@ -147,11 +147,11 @@ def __init__(self,
147147
if data_aug_func is None:
148148
self.aug_func = None
149149
elif data_aug_func == 'dsa':
150-
self.aug_func = DSA_Augmentation(aug_params)
150+
self.aug_func = DSA(aug_params)
151151
elif data_aug_func == 'mixup':
152-
self.aug_func = Mixup_Augmentation(aug_params)
152+
self.aug_func = Mixup(aug_params)
153153
elif data_aug_func == 'cutmix':
154-
self.aug_func = Cutmix_Augmentation(aug_params)
154+
self.aug_func = Cutmix(aug_params)
155155
else:
156156
raise ValueError(f"Invalid data augmentation function: {data_aug_func}")
157157

dd_ranking/metrics/hard_label.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from dd_ranking.utils import build_model, get_pretrained_model_path
1313
from dd_ranking.utils import TensorDataset, get_random_images, get_dataset, save_results
1414
from dd_ranking.utils import set_seed, train_one_epoch, validate, get_optimizer, get_lr_scheduler
15-
from dd_ranking.aug import DSAugmentation, MixupAugmentation, CutmixAugmentation, ZCAWhiteningAugmentation
15+
from dd_ranking.aug import DSA, Mixup, Cutmix, ZCAWhitening
1616
from dd_ranking.config import Config
1717

1818

@@ -82,13 +82,13 @@ def __init__(self, config: Config=None, dataset: str='CIFAR10', real_data_path:
8282
self.device = device
8383

8484
if data_aug_func == 'dsa':
85-
self.aug_func = DSA_Augmentation(aug_params)
85+
self.aug_func = DSA(aug_params)
8686
elif data_aug_func == 'zca':
87-
self.aug_func = ZCA_Whitening_Augmentation(aug_params)
87+
self.aug_func = ZCAWhitening(aug_params)
8888
elif data_aug_func == 'mixup':
89-
self.aug_func = Mixup_Augmentation(aug_params)
89+
self.aug_func = Mixup(aug_params)
9090
elif data_aug_func == 'cutmix':
91-
self.aug_func = Cutmix_Augmentation(aug_params)
91+
self.aug_func = Cutmix(aug_params)
9292
else:
9393
self.aug_func = None
9494

0 commit comments

Comments
 (0)