1+ """ Mixup and Cutmix
2+
3+ Papers:
4+ mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)
5+
6+ CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)
7+
8+ Code Reference:
9+ CutMix: https://github.com/clovaai/CutMix-PyTorch
10+
11+ Hacked together by Ross Wightman
12+ """

import math
from enum import IntEnum

import numpy as np
import torch


class MixupMode(IntEnum):
    """Mixing strategy: MIXUP, CUTMIX, or RANDOM (pick one of the first two per call)."""
    MIXUP = 0
    CUTMIX = 1
    RANDOM = 2

    @classmethod
    def from_str(cls, value):
        return cls[value.upper()]
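
# e.g. MixupMode.from_str('cutmix') -> MixupMode.CUTMIX; unknown names raise a KeyError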


def one_hot(x, num_classes, on_value=1., off_value=0., device='cuda'):
    x = x.long().view(-1, 1)
    return torch.full((x.size()[0], num_classes), off_value, device=device).scatter_(1, x, on_value)


def mixup_target(target, num_classes, lam=1., smoothing=0.0, device='cuda'):
    # smoothing splits mass so each row sums to 1: off = smoothing / C, on = 1 - smoothing + off
    off_value = smoothing / num_classes
    on_value = 1. - smoothing + off_value
    y1 = one_hot(target, num_classes, on_value=on_value, off_value=off_value, device=device)
    y2 = one_hot(target.flip(0), num_classes, on_value=on_value, off_value=off_value, device=device)
    return y1 * lam + y2 * (1. - lam)
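
# Worked example (values are illustrative assumptions): num_classes=5, smoothing=0.1 gives
# off=0.02, on=0.92; with lam=0.7 and target=[0, 4], row 0 mixes class 0 with class 4:
#
#   mixup_target(torch.tensor([0, 4]), 5, lam=0.7, smoothing=0.1, device='cpu')
#   # row 0 -> [0.65, 0.02, 0.02, 0.02, 0.29]   (0.7 * 0.92 + 0.3 * 0.02 = 0.65, etc.)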


def mixup_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False):
    lam = 1.
    if not disable:
        lam = np.random.beta(alpha, alpha)
    # blend each element with its mirror in the flipped batch
    input = input.mul(lam).add_(input.flip(0).mul(1. - lam))
    target = mixup_target(target, num_classes, lam, smoothing)
    return input, target


def rand_bbox(size, ratio):
    # sample a random (H * ratio) x (W * ratio) box, clipped to the image bounds
    H, W = size[-2:]
    ratio = max(min(ratio, 0.8), 0.2)
    cut_h, cut_w = int(H * ratio), int(W * ratio)
    cy, cx = np.random.randint(H), np.random.randint(W)
    yl, yh = np.clip(cy - cut_h // 2, 0, H), np.clip(cy + cut_h // 2, 0, H)
    xl, xh = np.clip(cx - cut_w // 2, 0, W), np.clip(cx + cut_w // 2, 0, W)
    return yl, yh, xl, xh
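
# For example (illustrative numbers): a cut for lam = 0.64 uses ratio = sqrt(1 - 0.64) = 0.6,
# so on a 224 x 224 image the box is at most 134 x 134, i.e. at most 0.36 of the image area:
#
#   yl, yh, xl, xh = rand_bbox((3, 224, 224), math.sqrt(1. - 0.64))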


def cutmix_batch(input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False):
    lam = 1.
    if not disable:
        lam = np.random.beta(alpha, alpha)
    if lam != 1:
        # paste a random box from the flipped batch, so element i mixes with element B - 1 - i
        ratio = math.sqrt(1. - lam)
        yl, yh, xl, xh = rand_bbox(input.size(), ratio)
        input[:, :, yl:yh, xl:xh] = input.flip(0)[:, :, yl:yh, xl:xh]
    target = mixup_target(target, num_classes, lam, smoothing)
    return input, target


def _resolve_mode(mode):
    mode = MixupMode.from_str(mode) if isinstance(mode, str) else mode
    if mode == MixupMode.RANDOM:
        mode = MixupMode(int(np.random.rand() > 0.5))
    return mode  # will be one of cutmix or mixup


def mix_batch(
        input, target, alpha=0.2, num_classes=1000, smoothing=0.1, disable=False, mode=MixupMode.MIXUP):
    mode = _resolve_mode(mode)
    if mode == MixupMode.CUTMIX:
        return cutmix_batch(input, target, alpha, num_classes, smoothing, disable)
    else:
        return mixup_batch(input, target, alpha, num_classes, smoothing, disable)
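
# Usage sketch (illustrative; the tensors, shapes, and class count are assumptions):
#
#   images = torch.rand(8, 3, 224, 224, device='cuda')        # NCHW, even batch size
#   labels = torch.randint(0, 1000, (8,), device='cuda')
#   images, soft_targets = mix_batch(images, labels, alpha=0.2, mode='random')
#   # soft_targets: (8, 1000) soft labels for a soft-target cross-entropy loss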


class FastCollateMixup:
    """Fast Collate Mixup that applies different params to each element + flipped pair

    NOTE once experiments are done, one of the three variants will remain with this class name
    """
    def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000, mode=MixupMode.MIXUP):
        self.mixup_alpha = mixup_alpha
        self.label_smoothing = label_smoothing
        self.num_classes = num_classes
        self.mode = MixupMode.from_str(mode) if isinstance(mode, str) else mode
        self.mixup_enabled = True
        self.correct_lam = False  # correct lambda based on clipped area for cutmix

    def _do_mix(self, tensor, batch):
        batch_size = len(batch)
        lam_out = torch.ones(batch_size)
        for i in range(batch_size // 2):
            j = batch_size - i - 1
            lam = 1.
            if self.mixup_enabled:
                lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)

            if _resolve_mode(self.mode) == MixupMode.CUTMIX:
                mixed_i, mixed_j = batch[i][0].astype(np.float32), batch[j][0].astype(np.float32)
                ratio = math.sqrt(1. - lam)
                if lam != 1:
                    # swap the same random box between the paired elements i and j
                    yl, yh, xl, xh = rand_bbox(tensor.size(), ratio)
                    mixed_i[:, yl:yh, xl:xh] = batch[j][0][:, yl:yh, xl:xh].astype(np.float32)
                    mixed_j[:, yl:yh, xl:xh] = batch[i][0][:, yl:yh, xl:xh].astype(np.float32)
                    if self.correct_lam:
                        # recompute lam from the clipped box area actually pasted
                        lam_corrected = (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1])
                        lam_out[i] -= lam_corrected
                        lam_out[j] -= lam_corrected
                    else:
                        lam_out[i] = lam
                        lam_out[j] = lam
            else:
                mixed_i = batch[i][0].astype(np.float32) * lam + batch[j][0].astype(np.float32) * (1 - lam)
                mixed_j = batch[j][0].astype(np.float32) * lam + batch[i][0].astype(np.float32) * (1 - lam)
                lam_out[i] = lam
                lam_out[j] = lam
            np.round(mixed_i, out=mixed_i)
            np.round(mixed_j, out=mixed_j)
            tensor[i] += torch.from_numpy(mixed_i.astype(np.uint8))
            tensor[j] += torch.from_numpy(mixed_j.astype(np.uint8))
        return lam_out

    def __call__(self, batch):
        batch_size = len(batch)
        assert batch_size % 2 == 0, 'Batch size should be even when using this'
        tensor = torch.zeros((batch_size, *batch[0][0].shape), dtype=torch.uint8)
        lam = self._do_mix(tensor, batch)
        target = torch.tensor([b[1] for b in batch], dtype=torch.int64)
        target = mixup_target(target, self.num_classes, lam.unsqueeze(1), self.label_smoothing, device='cpu')
        return tensor, target
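
    # Usage sketch (illustrative; dataset and sizes are assumptions): use as the collate_fn
    # of a DataLoader over a dataset yielding (uint8 CHW numpy image, int label) pairs.
    #
    #   collate = FastCollateMixup(mixup_alpha=0.2, label_smoothing=0.1, num_classes=1000)
    #   loader = torch.utils.data.DataLoader(dataset, batch_size=128, collate_fn=collate)
    #   inputs, soft_targets = next(iter(loader))   # uint8 (B, C, H, W), float (B, 1000)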


class FastCollateMixupElementwise(FastCollateMixup):
    """Fast Collate Mixup that applies different params to each batch element

    NOTE this is for experimentation, may remove at some point
    """
    def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000, mode=MixupMode.MIXUP):
        super(FastCollateMixupElementwise, self).__init__(mixup_alpha, label_smoothing, num_classes, mode)

    def _do_mix(self, tensor, batch):
        batch_size = len(batch)
        lam_out = torch.ones(batch_size)
        for i in range(batch_size):
            lam = 1.
            if self.mixup_enabled:
                lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)

            if _resolve_mode(self.mode) == MixupMode.CUTMIX:
                mixed = batch[i][0].astype(np.float32)
                ratio = math.sqrt(1. - lam)
                if lam != 1:
                    # each element gets its own random box from its mirrored partner
                    yl, yh, xl, xh = rand_bbox(tensor.size(), ratio)
                    mixed[:, yl:yh, xl:xh] = batch[batch_size - i - 1][0][:, yl:yh, xl:xh].astype(np.float32)
                    if self.correct_lam:
                        lam_out[i] -= (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1])
                    else:
                        lam_out[i] = lam
            else:
                mixed = batch[i][0].astype(np.float32) * lam + \
                    batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam)
                lam_out[i] = lam
            np.round(mixed, out=mixed)
            tensor[i] += torch.from_numpy(mixed.astype(np.uint8))
        return lam_out


class FastCollateMixupBatchwise(FastCollateMixup):
    """Fast Collate Mixup that applies same params to whole batch

    NOTE this is for experimentation, may remove at some point
    """

    def __init__(self, mixup_alpha=1., label_smoothing=0.1, num_classes=1000, mode=MixupMode.MIXUP):
        super(FastCollateMixupBatchwise, self).__init__(mixup_alpha, label_smoothing, num_classes, mode)

    def _do_mix(self, tensor, batch):
        batch_size = len(batch)
        lam = 1.
        cutmix = _resolve_mode(self.mode) == MixupMode.CUTMIX
        if self.mixup_enabled:
            lam = np.random.beta(self.mixup_alpha, self.mixup_alpha)
            if cutmix:
                # one box is shared by the whole batch, so sample it up front
                ratio = math.sqrt(1. - lam)
                yl, yh, xl, xh = rand_bbox(batch[0][0].shape, ratio)
                if self.correct_lam:
                    lam = 1 - (yh - yl) * (xh - xl) / (tensor.shape[-2] * tensor.shape[-1])

        for i in range(batch_size):
            if cutmix:
                mixed = batch[i][0].astype(np.float32)
                if lam != 1:
                    mixed[:, yl:yh, xl:xh] = batch[batch_size - i - 1][0][:, yl:yh, xl:xh].astype(np.float32)
            else:
                mixed = batch[i][0].astype(np.float32) * lam + \
                    batch[batch_size - i - 1][0].astype(np.float32) * (1 - lam)
            np.round(mixed, out=mixed)
            tensor[i] += torch.from_numpy(mixed.astype(np.uint8))
        # a single scalar lam applies to every element; expand so __call__ can broadcast it
        return torch.full((batch_size,), float(lam))
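

if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the original module): build a fake batch of
    # four (uint8 CHW image, int label) samples and run each collate variant on CPU.
    fake_batch = [(np.random.randint(0, 256, (3, 32, 32), dtype=np.uint8), i) for i in range(4)]
    for collate_cls in (FastCollateMixup, FastCollateMixupElementwise, FastCollateMixupBatchwise):
        collate = collate_cls(mixup_alpha=1., label_smoothing=0.1, num_classes=10, mode='random')
        inputs, soft_targets = collate(fake_batch)
        print(collate_cls.__name__, tuple(inputs.shape), tuple(soft_targets.shape))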