@@ -56,7 +56,7 @@ def __getitem__(self, i):
5656 return self .features [i ], self .targets [i ]
5757
5858
59- def train (dataloader , input_shape , output_shape , weight_decay , lr , epochs , autocast , device , seed ):
59+ def train (dataloader , input_shape , output_shape , weight_decay , lr , epochs , amp , device , seed ):
6060 torch .manual_seed (seed )
6161 model = torch .nn .Linear (input_shape , output_shape )
6262 devices = [x for x in range (torch .cuda .device_count ())]
@@ -81,7 +81,7 @@ def train(dataloader, input_shape, output_shape, weight_decay, lr, epochs, autoc
8181 scheduler (step )
8282
8383 optimizer .zero_grad ()
84- with autocast ():
84+ with torch . autocast (device , enabled = amp ):
8585 pred = model (x )
8686 loss = criterion (pred , y )
8787
@@ -107,14 +107,14 @@ def train(dataloader, input_shape, output_shape, weight_decay, lr, epochs, autoc
107107 return model
108108
109109
110- def infer (model , dataloader , autocast , device ):
110+ def infer (model , dataloader , amp , device ):
111111 true , pred = [], []
112112 with torch .no_grad ():
113113 for x , y in tqdm (dataloader ):
114114 x = x .to (device )
115115 y = y .to (device )
116116
117- with autocast ():
117+ with torch . autocast (device , enabled = amp ):
118118 logits = model (x )
119119
120120 pred .append (logits .cpu ())
@@ -125,12 +125,12 @@ def infer(model, dataloader, autocast, device):
125125 return logits , target
126126
127127
128- def find_peak (wd_list , idxs , train_loader , val_loader , input_shape , output_shape , lr , epochs , autocast , device , verbose , seed ):
128+ def find_peak (wd_list , idxs , train_loader , val_loader , input_shape , output_shape , lr , epochs , amp , device , verbose , seed ):
129129 best_wd_idx , max_acc = 0 , 0
130130 for idx in idxs :
131131 weight_decay = wd_list [idx ]
132- model = train (train_loader , input_shape , output_shape , weight_decay , lr , epochs , autocast , device , seed )
133- logits , target = infer (model , val_loader , autocast , device )
132+ model = train (train_loader , input_shape , output_shape , weight_decay , lr , epochs , amp , device , seed )
133+ logits , target = infer (model , val_loader , amp , device )
134134 acc1 , = accuracy (logits .float (), target .float (), topk = (1 ,))
135135 if verbose :
136136 print (f"Valid accuracy with weight_decay { weight_decay } : { acc1 } " )
@@ -150,7 +150,6 @@ def evaluate(model, train_dataloader, dataloader, fewshot_k, batch_size, num_wor
150150 os .mkdir (feature_dir )
151151
152152 featurizer = Featurizer (model , normalize ).cuda ()
153- autocast = torch .cuda .amp .autocast if amp else suppress
154153 if not os .path .exists (os .path .join (feature_dir , 'targets_train.pt' )):
155154 # now we have to cache the features
156155 devices = [x for x in range (torch .cuda .device_count ())]
@@ -168,7 +167,7 @@ def evaluate(model, train_dataloader, dataloader, fewshot_k, batch_size, num_wor
168167 for images , target in tqdm (loader ):
169168 images = images .to (device )
170169
171- with autocast ():
170+ with torch . autocast (device , enabled = amp ):
172171 feature = featurizer (images )
173172
174173 features .append (feature .cpu ())
@@ -270,11 +269,11 @@ def evaluate(model, train_dataloader, dataloader, fewshot_k, batch_size, num_wor
270269 wd_list = np .logspace (- 6 , 2 , num = 97 ).tolist ()
271270 wd_list_init = np .logspace (- 6 , 2 , num = 7 ).tolist ()
272271 wd_init_idx = [i for i , val in enumerate (wd_list ) if val in wd_list_init ]
273- peak_idx = find_peak (wd_list , wd_init_idx , feature_train_loader , feature_val_loader , input_shape , output_shape , lr , epochs , autocast , device , verbose , seed )
272+ peak_idx = find_peak (wd_list , wd_init_idx , feature_train_loader , feature_val_loader , input_shape , output_shape , lr , epochs , amp , device , verbose , seed )
274273 step_span = 8
275274 while step_span > 0 :
276275 left , right = max (peak_idx - step_span , 0 ), min (peak_idx + step_span , len (wd_list )- 1 )
277- peak_idx = find_peak (wd_list , [left , peak_idx , right ], feature_train_loader , feature_val_loader , input_shape , output_shape , lr , epochs , autocast , device , verbose , seed )
276+ peak_idx = find_peak (wd_list , [left , peak_idx , right ], feature_train_loader , feature_val_loader , input_shape , output_shape , lr , epochs , amp , device , verbose , seed )
278277 step_span //= 2
279278 best_wd = wd_list [peak_idx ]
280279 if fewshot_k < 0 :
@@ -288,8 +287,8 @@ def evaluate(model, train_dataloader, dataloader, fewshot_k, batch_size, num_wor
288287 best_wd = 0
289288 train_loader = feature_train_loader
290289
291- final_model = train (train_loader , input_shape , output_shape , best_wd , lr , epochs , autocast , device , seed )
292- logits , target = infer (final_model , feature_test_loader , autocast , device )
290+ final_model = train (train_loader , input_shape , output_shape , best_wd , lr , epochs , amp , device , seed )
291+ logits , target = infer (final_model , feature_test_loader , amp , device )
293292 pred = logits .argmax (axis = 1 )
294293
295294 # measure accuracy