33import torch .nn .functional as F
44
55
6- class DSAugmentation :
6+ class DSA :
77
88 def __init__ (self , params : dict , seed : int = - 1 , aug_mode : str = 'S' ):
99 self .params = params
@@ -30,7 +30,7 @@ def set_seed_DiffAug(self):
3030 def rand_scale (self , x ):
3131 # x>1, max scale
3232 # sx, sy: (0, +oo), 1: orignial size, 0.5: enlarge 2 times
33- ratio = self .params ["ratio_scale " ]
33+ ratio = self .params ["scale " ]
3434 self .set_seed_DiffAug ()
3535 sx = torch .rand (x .shape [0 ]) * (ratio - 1.0 / ratio ) + 1.0 / ratio
3636 self .set_seed_DiffAug ()
@@ -45,7 +45,7 @@ def rand_scale(self, x):
4545 return x
4646
4747 def rand_rotate (self , x ): # [-180, 180], 90: anticlockwise 90 degree
48- ratio = self .params ["ratio_rotate " ]
48+ ratio = self .params ["rotate " ]
4949 self .set_seed_DiffAug ()
5050 theta = (torch .rand (x .shape [0 ]) - 0.5 ) * 2 * ratio / 180 * float (np .pi )
5151 theta = [[[torch .cos (theta [i ]), torch .sin (- theta [i ]), 0 ],
@@ -58,7 +58,7 @@ def rand_rotate(self, x): # [-180, 180], 90: anticlockwise 90 degree
5858 return x
5959
6060 def rand_flip (self , x ):
61- prob = self .params ["prob_flip " ]
61+ prob = self .params ["flip " ]
6262 self .set_seed_DiffAug ()
6363 randf = torch .rand (x .size (0 ), 1 , 1 , 1 , device = x .device )
6464 if self .params ["siamese" ]: # Siamese augmentation:
@@ -99,7 +99,7 @@ def rand_color(self, x):
9999
100100 def rand_crop (self , x ):
101101 # The image is padded on its surrounding and then cropped.
102- ratio = self .params ["ratio_crop_pad " ]
102+ ratio = self .params ["crop " ]
103103 shift_x , shift_y = int (x .size (2 ) * ratio + 0.5 ), int (x .size (3 ) * ratio + 0.5 )
104104 self .set_seed_DiffAug ()
105105 translation_x = torch .randint (- shift_x , shift_x + 1 , size = [x .size (0 ), 1 , 1 ], device = x .device )
@@ -120,7 +120,7 @@ def rand_crop(self, x):
120120 return x
121121
122122 def rand_cutout (self , x ):
123- ratio = self .params ["ratio_cutout " ]
123+ ratio = self .params ["cutout " ]
124124 cutout_size = int (x .size (2 ) * ratio + 0.5 ), int (x .size (3 ) * ratio + 0.5 )
125125 self .set_seed_DiffAug ()
126126 offset_x = torch .randint (0 , x .size (2 ) + (1 - cutout_size [0 ] % 2 ), size = [x .size (0 ), 1 , 1 ], device = x .device )
0 commit comments