from typing import Iterable

import torch
import util.lr_sched as lr_sched
import util.misc as misc


def train_one_epoch(
    model: torch.nn.Module,
    data_loader: Iterable,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    epoch: int,
    loss_scaler,
    log_writer=None,
    args=None,
):
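    # metric_logger below is assumed to be the MAE-style helper from util.misc:
    # SmoothedValue keeps a windowed running average of each tracked scalar and
    # log_every() prints the meters every print_freq iterations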
    model.train(True)
    metric_logger = misc.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = "Epoch: [{}]".format(epoch)
    print_freq = 10

    accum_iter = args.accum_iter

    optimizer.zero_grad()

    if log_writer is not None:
        print("log_dir: {}".format(log_writer.log_dir))
    for data_iter_step, (examples, labels, example_mask) in enumerate(
        metric_logger.log_every(data_loader, print_freq, header)
    ):
        # we use a per iteration (instead of per epoch) lr scheduler
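        # adjust_learning_rate receives a fractional epoch, data_iter_step / len(data_loader) + epoch,
        # so the learning rate moves smoothly within an epoch; the exact schedule (commonly linear
        # warmup followed by cosine decay in MAE-style util.lr_sched) is an assumption here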
        if data_iter_step % accum_iter == 0:
            lr_sched.adjust_learning_rate(optimizer, data_iter_step / len(data_loader) + epoch, args)
        # ... unchanged lines omitted here in the commit diff (@@ -43,8 +48,7 @@) ...

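        # gradient accumulation: the loss is divided by accum_iter so that gradients
        # summed over accum_iter micro-batches match one large batch; loss_scaler is
        # expected to call optimizer.step() only when update_grad is True below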
        loss /= accum_iter

        loss_scaler(loss, optimizer, parameters=model.parameters(), update_grad=(data_iter_step + 1) % accum_iter == 0)
        if (data_iter_step + 1) % accum_iter == 0:
            optimizer.zero_grad()

        # ... unchanged lines omitted here in the commit diff (@@ -55,42 +59,49 @@) ...
        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

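        # all_reduce_mean averages the python scalars across distributed processes,
        # so the logged loss reflects all ranks rather than just the local one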
        misc.all_reduce_mean(loss_value)
        c_loss_value_reduce = misc.all_reduce_mean(c_loss_value)

        if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
            """We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar("c_train_loss", c_loss_value_reduce, epoch_1000x)
            log_writer.add_scalar("lr", lr, epoch_1000x)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)
    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


def val_one_epoch(
    model: torch.nn.Module,
    data_loader: Iterable,
    optimizer: torch.optim.Optimizer,
    device: torch.device,
    epoch: int,
    loss_scaler,
    log_writer=None,
    args=None,
):
    model.eval()
    metric_logger = misc.MetricLogger(delimiter="  ")
    metric_logger.add_meter("lr", misc.SmoothedValue(window_size=1, fmt="{value:.6f}"))
    header = "Epoch: [{}]".format(epoch)
    print_freq = 10

    accum_iter = args.accum_iter

    if log_writer is not None:
        print("log_dir: {}".format(log_writer.log_dir))
    for data_iter_step, (examples, labels, example_mask) in enumerate(
        metric_logger.log_every(data_loader, print_freq, header)
    ):

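        # model.eval() above disables training-time behaviour such as dropout, and
        # torch.no_grad() additionally skips autograd tracking for this forward pass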
        with torch.no_grad():
            c_loss = model(examples, labels)
        loss = c_loss
        loss_value = loss.item()

        # ... unchanged lines omitted here in the commit diff (@@ -105,15 +116,15 @@) ...
        lr = optimizer.param_groups[0]["lr"]
        metric_logger.update(lr=lr)

        misc.all_reduce_mean(loss_value)
        c_loss_value_reduce = misc.all_reduce_mean(c_loss_value)
        if log_writer is not None and (data_iter_step + 1) % accum_iter == 0:
            """We use epoch_1000x as the x-axis in tensorboard.
            This calibrates different curves when batch size changes.
            """
            epoch_1000x = int((data_iter_step / len(data_loader) + epoch) * 1000)
            log_writer.add_scalar("c_train_loss", c_loss_value_reduce, epoch_1000x)
            log_writer.add_scalar("lr", lr, epoch_1000x)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
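

# Usage sketch: one way train_one_epoch / val_one_epoch are typically driven
# from a finetuning script. Everything below is illustrative: building the
# model, the data loaders, the optimizer and the argparse namespace is assumed
# to happen elsewhere, and `args.epochs` plus a NativeScaler-style
# `loss_scaler` from util.misc are assumptions about the surrounding project.
def run_finetuning(model, train_loader, val_loader, optimizer, loss_scaler, device, args, log_writer=None):
    for epoch in range(args.epochs):
        train_stats = train_one_epoch(
            model, train_loader, optimizer, device, epoch, loss_scaler,
            log_writer=log_writer, args=args,
        )
        # validation mirrors the training loop but runs the model in eval mode
        val_one_epoch(
            model, val_loader, optimizer, device, epoch, loss_scaler,
            log_writer=log_writer, args=args,
        )
        print("epoch {}: averaged train stats {}".format(epoch, train_stats))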