@@ -15,14 +15,14 @@ def parse_args():
 
     # hyperparameters sent by the client (same flag names as estimator hyperparameters)
     p.add_argument("--seed", type=int, default=42)
-    p.add_argument("--batch-size", type=int, default=32)
-    p.add_argument("--num-epochs-phase1", type=int, default=3)
+    p.add_argument("--batch-size", type=int, default=512)
+    p.add_argument("--num-epochs-phase1", type=int, default=2)
     p.add_argument("--num-epochs-phase2", type=int, default=2)
-    p.add_argument("--lr-head", type=float, default=1e-3)
-    p.add_argument("--lr-backbone", type=float, default=1e-4)
+    p.add_argument("--lr-head", type=float, default=16e-3)
+    p.add_argument("--lr-backbone", type=float, default=16e-4)
     p.add_argument("--patience", type=int, default=3)
-    p.add_argument("--num-workers", type=int, default=4)
-    p.add_argument("--img-size", type=int, default=224)
+    p.add_argument("--num-workers", type=int, default=2)
+    p.add_argument("--img-size", type=int, default=384)
 
     # other variables
     p.add_argument("--wandb-project", type=str, default="food101-classifier")
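Side note on the new defaults above: the batch size grows 16x (32 to 512) and both learning rates grow by the same factor (1e-3 to 16e-3, 1e-4 to 16e-4), which is consistent with the common linear LR scaling rule; that reading is an observation, not something stated in the commit. A minimal sketch of the arithmetic, where every name is illustrative and only the numbers come from the diff:

# Linear scaling rule (sketch): scale the learning rate with the batch size.
# All names below are illustrative; only the numbers come from the diff above.
base_batch_size = 32
new_batch_size = 512
scale = new_batch_size / base_batch_size   # 512 / 32 = 16
scaled_lr_head = 1e-3 * scale              # 0.016  == 16e-3
scaled_lr_backbone = 1e-4 * scale          # 0.0016 == 16e-4
print(scaled_lr_head, scaled_lr_backbone)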
@@ -119,7 +119,7 @@ def main():
                              [0.229, 0.224, 0.225])
     ])
     test_tfms = transforms.Compose([
-        transforms.Resize(256),               # shrink so short edge=256
+        transforms.Resize(512),               # shrink so short edge=512
         transforms.CenterCrop(cfg.img_size),  # take middle window
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406],
@@ -192,6 +192,7 @@ def build_model(num_classes: int) -> nn.Module:
     else:
         scaler = None
 
+    step_counters = {'train': 0, 'val': 0}
     def epoch_loop(phase: str,
                    model: nn.Module,
                    loader: DataLoader,
@@ -200,6 +201,8 @@ def epoch_loop (phase: str,
                    ) -> Tuple[float, float]:
         is_train = optimizer is not None
         model.train() if is_train else model.eval()
+        # Use autocast if CUDA, else normal FP32
+        context = autocast('cuda') if torch.cuda.is_available() else nullcontext()  # increase GPU efficiency with autocast if available
 
         run_loss, run_correct, imgs_processed = 0.0, 0, 0
         t0 = time.time()
@@ -210,8 +213,6 @@ def epoch_loop (phase: str,
 
             optimizer.zero_grad(set_to_none=True) if is_train else None  # saves a bit of GPU memory when setting to `none`
 
-            # Use autocast if CUDA, else normal FP32
-            context = autocast('cuda') if torch.cuda.is_available() else nullcontext()  # increase GPU efficiency with autocast if available
             with context:
                 outputs = model(x)
                 loss = criterion(outputs, y)
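For context on the hunk above: the autocast-or-nullcontext choice is now built once per epoch instead of once per batch, and the same `context` is reused inside the loop. Below is a minimal, self-contained sketch of that pattern, assuming the `torch.amp` autocast/GradScaler combination this script appears to use; the dummy model, data, and optimizer are illustrative stand-ins, not names from the script.

import torch
import torch.nn as nn
from contextlib import nullcontext
from torch.amp import autocast, GradScaler

# Illustrative stand-ins; none of these names come from the training script.
use_cuda = torch.cuda.is_available()
device = 'cuda' if use_cuda else 'cpu'
model = nn.Linear(8, 2).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler = GradScaler('cuda') if use_cuda else None

# Built once, outside the batch loop, as in the diff above.
context = autocast('cuda') if use_cuda else nullcontext()

for _ in range(3):  # stand-in for iterating over a DataLoader
    x = torch.randn(4, 8, device=device)
    y = torch.randint(0, 2, (4,), device=device)
    optimizer.zero_grad(set_to_none=True)
    with context:                        # mixed precision on CUDA, plain FP32 on CPU
        loss = criterion(model(x), y)
    if scaler is not None:
        scaler.scale(loss).backward()    # scaled backward guards against FP16 underflow
        scaler.step(optimizer)
        scaler.update()
    else:
        loss.backward()
        optimizer.step()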
@@ -243,7 +244,8 @@ def epoch_loop (phase: str,
             if phase in ["train", "val"]:
                 wandb.log({
                     f"{phase}/batch_loss": loss.item(),
-                })
+                }, step=step_counters[phase])
+                step_counters[phase] += 1
 
         if torch.cuda.is_available():
             torch.cuda.synchronize()  # CPU waits until GPU finishes. More accurate dt.
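On the hunk above: passing an explicit `step` to `wandb.log` ties each batch metric to a per-phase counter, and the epoch-level log later reuses `step_counters[phase] - 1` so it lands on the same step as that epoch's final batch. A tiny sketch of the idea; the metric names and values are made up, and `mode="offline"` is only there so the sketch runs without a W&B account.

import wandb

run = wandb.init(project="food101-classifier", mode="offline")
step_counters = {'train': 0, 'val': 0}

for batch_loss in [0.9, 0.7, 0.5]:      # stand-in for per-batch training losses
    wandb.log({"train/batch_loss": batch_loss}, step=step_counters['train'])
    step_counters['train'] += 1

# Epoch summary logged at the same step as the last batch of the epoch.
wandb.log({"train/epoch_loss": 0.7}, step=step_counters['train'] - 1)
run.finish()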
@@ -256,14 +258,15 @@ def epoch_loop (phase: str,
             loss_scale = scaler.get_scale()
             peak_mem_MB = torch.cuda.max_memory_allocated() / 1024**2
 
-        # logging
-        print(f"{phase:5} | loss {epoch_loss:.4f} | acc {epoch_acc:.4f} | {dt:5.1f}s | {throughput:7.1f} samples/s")
-        if is_train and scaler:
-            print(f"loss_scale={loss_scale:.0f} peak_mem={peak_mem_MB:.0f} MB")
-            torch.cuda.reset_peak_memory_stats()
-
-        # wandb: epoch logging (train & val only)
+        # epoch logging (train & val only)
         if phase in ["train", "val"]:
+            # print statements
+            print(f"{phase:5} | loss {epoch_loss:.4f} | acc {epoch_acc:.4f} | {dt:5.1f}s | {throughput:7.1f} samples/s")
+            if is_train and scaler:
+                print(f"loss_scale={loss_scale:.0f} peak_mem={peak_mem_MB:.0f} MB")
+                torch.cuda.reset_peak_memory_stats()
+
+            # wandb
             metrics = {
                 f"{phase}/epoch_loss": epoch_loss,
                 f"{phase}/epoch_acc": epoch_acc,
@@ -275,7 +278,7 @@ def epoch_loop (phase: str,
275278 f"{ phase } /loss_scale" : loss_scale ,
276279 f"{ phase } /peak_mem_MB" : peak_mem_MB ,
277280 })
278- wandb .log (metrics )
281+ wandb .log (metrics , step = step_counters [ phase ] - 1 ) # ensures logging at the same step as the last batch of that epoch
279282 return epoch_loss , epoch_acc
280283
281284 # checkpoint helper
@@ -287,7 +290,7 @@ def save_ckpt(state: Dict, filename: str, model_dir: str) -> None:
 
     # ---------- Training and Evaluation ----------
     # phase 1: feature extraction (freeze backbone, train only the new head)
-    print("Phase 1: feature extraction")
+    print("\nPhase 1: feature extraction")
 
     optimizer = optim.Adam(model.classifier[1].parameters(), lr=cfg.lr_head)
 
@@ -329,7 +332,7 @@ def save_ckpt(state: Dict, filename: str, model_dir: str) -> None:
             break
 
     # Phase 2: fine-tune (unfreeze backbone, train whole model at lower LR)
-    print("Phase 2: fine-tune")
+    print("\nPhase 2: fine-tune")
 
     # unfreeze backbone
     for p in model.parameters():
@@ -365,9 +368,10 @@ def save_ckpt(state: Dict, filename: str, model_dir: str) -> None:
             save_ckpt(ckpt, "best_backbone.pth", cfg.model_dir)
 
     # final test
+    print("\nFinal test")
     model.eval()
     _, test_acc = epoch_loop("test", model, test_dl)
-    print(f"Final Test Acc: {test_acc:.4f}")
+    print(f"\nFinal Test Acc: {test_acc:.4f}")
     wandb.summary["test_acc"] = test_acc
 
     # save final model