 
 
 def update_train_loss(
-    trainer: "solver.Solver", loss_dict: Dict[str, float], batch_size: int
+    solver: "solver.Solver", loss_dict: Dict[str, float], batch_size: int
 ):
     for key in loss_dict:
-        if key not in trainer.train_output_info:
-            trainer.train_output_info[key] = misc.AverageMeter(key, "7.5f")
-        trainer.train_output_info[key].update(float(loss_dict[key]), batch_size)
-        if key not in trainer.train_loss_info:
-            trainer.train_loss_info[key] = misc.AverageMeter(key, ".5f")
-        trainer.train_loss_info[key].update(float(loss_dict[key]))
+        if key not in solver.train_output_info:
+            solver.train_output_info[key] = misc.AverageMeter(key, "7.5f")
+        solver.train_output_info[key].update(float(loss_dict[key]), batch_size)
+        if key not in solver.train_loss_info:
+            solver.train_loss_info[key] = misc.AverageMeter(key, ".5f")
+        solver.train_loss_info[key].update(float(loss_dict[key]))
 
 
 def update_eval_loss(
-    trainer: "solver.Solver", loss_dict: Dict[str, float], batch_size: int
+    solver: "solver.Solver", loss_dict: Dict[str, float], batch_size: int
 ):
     for key in loss_dict:
-        if key not in trainer.eval_output_info:
-            trainer.eval_output_info[key] = misc.AverageMeter(key, "7.5f")
-        trainer.eval_output_info[key].update(float(loss_dict[key]), batch_size)
+        if key not in solver.eval_output_info:
+            solver.eval_output_info[key] = misc.AverageMeter(key, "7.5f")
+        solver.eval_output_info[key].update(float(loss_dict[key]), batch_size)
 
 
 def log_train_info(
-    trainer: "solver.Solver", batch_size: int, epoch_id: int, iter_id: int
+    solver: "solver.Solver", batch_size: int, epoch_id: int, iter_id: int
 ):
-    lr_msg = f"lr: {trainer.optimizer.get_lr():.5f}"
+    lr_msg = f"lr: {solver.optimizer.get_lr():.5f}"
 
     metric_msg = ", ".join(
         [
-            f"{key}: {trainer.train_output_info[key].avg:.5f}"
-            for key in trainer.train_output_info
+            f"{key}: {solver.train_output_info[key].avg:.5f}"
+            for key in solver.train_output_info
         ]
     )
 
     time_msg = ", ".join(
-        [trainer.train_time_info[key].mean for key in trainer.train_time_info]
+        [solver.train_time_info[key].mean for key in solver.train_time_info]
     )
 
-    ips_msg = f"ips: {batch_size / trainer.train_time_info['batch_cost'].avg:.2f}"
-    if trainer.benchmark_flag:
+    ips_msg = f"ips: {batch_size / solver.train_time_info['batch_cost'].avg:.2f}"
+    if solver.benchmark_flag:
         ips_msg += " samples/s"
 
     eta_sec = (
-        (trainer.epochs - epoch_id + 1) * trainer.iters_per_epoch - iter_id
-    ) * trainer.train_time_info["batch_cost"].avg
+        (solver.epochs - epoch_id + 1) * solver.iters_per_epoch - iter_id
+    ) * solver.train_time_info["batch_cost"].avg
     eta_msg = f"eta: {str(datetime.timedelta(seconds=int(eta_sec)))}"
 
-    epoch_width = len(str(trainer.epochs))
-    iters_width = len(str(trainer.iters_per_epoch))
+    epoch_width = len(str(solver.epochs))
+    iters_width = len(str(solver.iters_per_epoch))
     log_str = (
-        f"[Train][Epoch {epoch_id:>{epoch_width}}/{trainer.epochs}]"
-        f"[Iter {iter_id:>{iters_width}}/{trainer.iters_per_epoch}] {lr_msg}, "
+        f"[Train][Epoch {epoch_id:>{epoch_width}}/{solver.epochs}]"
+        f"[Iter {iter_id:>{iters_width}}/{solver.iters_per_epoch}] {lr_msg}, "
         f"{metric_msg}, {time_msg}, {ips_msg}, {eta_msg}"
     )
-    if trainer.benchmark_flag:
+    if solver.benchmark_flag:
         max_mem_reserved_msg = (
             f"max_mem_reserved: {device.cuda.max_memory_reserved() // (1 << 20)} MB"
         )
@@ -94,57 +94,57 @@ def log_train_info(
 
     logger.scalar(
         {
-            "train/lr": trainer.optimizer.get_lr(),
+            "train/lr": solver.optimizer.get_lr(),
             **{
-                f"train/{key}": trainer.train_output_info[key].avg
-                for key in trainer.train_output_info
+                f"train/{key}": solver.train_output_info[key].avg
+                for key in solver.train_output_info
             },
         },
-        step=trainer.global_step,
-        vdl_writer=trainer.vdl_writer,
-        wandb_writer=trainer.wandb_writer,
-        tbd_writer=trainer.tbd_writer,
+        step=solver.global_step,
+        vdl_writer=solver.vdl_writer,
+        wandb_writer=solver.wandb_writer,
+        tbd_writer=solver.tbd_writer,
     )
 
 
 def log_eval_info(
-    trainer: "solver.Solver",
+    solver: "solver.Solver",
     batch_size: int,
     epoch_id: int,
     iters_per_epoch: int,
     iter_id: int,
 ):
     metric_msg = ", ".join(
         [
-            f"{key}: {trainer.eval_output_info[key].avg:.5f}"
-            for key in trainer.eval_output_info
+            f"{key}: {solver.eval_output_info[key].avg:.5f}"
+            for key in solver.eval_output_info
         ]
     )
 
     time_msg = ", ".join(
-        [trainer.eval_time_info[key].mean for key in trainer.eval_time_info]
+        [solver.eval_time_info[key].mean for key in solver.eval_time_info]
     )
 
-    ips_msg = f"ips: {batch_size / trainer.eval_time_info['batch_cost'].avg:.2f}"
+    ips_msg = f"ips: {batch_size / solver.eval_time_info['batch_cost'].avg:.2f}"
 
-    eta_sec = (iters_per_epoch - iter_id) * trainer.eval_time_info["batch_cost"].avg
+    eta_sec = (iters_per_epoch - iter_id) * solver.eval_time_info["batch_cost"].avg
     eta_msg = f"eta: {str(datetime.timedelta(seconds=int(eta_sec)))}"
 
-    epoch_width = len(str(trainer.epochs))
+    epoch_width = len(str(solver.epochs))
     iters_width = len(str(iters_per_epoch))
     logger.info(
-        f"[Eval][Epoch {epoch_id:>{epoch_width}}/{trainer.epochs}]"
+        f"[Eval][Epoch {epoch_id:>{epoch_width}}/{solver.epochs}]"
         f"[Iter {iter_id:>{iters_width}}/{iters_per_epoch}] "
         f"{metric_msg}, {time_msg}, {ips_msg}, {eta_msg}"
     )
 
     logger.scalar(
         {
-            f"eval/{key}": trainer.eval_output_info[key].avg
-            for key in trainer.eval_output_info
+            f"eval/{key}": solver.eval_output_info[key].avg
+            for key in solver.eval_output_info
         },
-        step=trainer.global_step,
-        vdl_writer=trainer.vdl_writer,
-        wandb_writer=trainer.wandb_writer,
-        tbd_writer=trainer.tbd_writer,
+        step=solver.global_step,
+        vdl_writer=solver.vdl_writer,
+        wandb_writer=solver.wandb_writer,
+        tbd_writer=solver.tbd_writer,
     )
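
All four helpers above lean on a running-average meter keyed by loss or timing name: update(val, n) folds in one observation weighted by n, .avg reads back the weighted mean, and .mean yields the formatted "name: value" string joined into time_msg. Below is a minimal sketch of that interface, inferred from the calls in this diff; it is an assumption, not the repository's actual misc.AverageMeter, whose details may differ.

# Sketch of the meter interface assumed by update_train_loss / log_train_info.
# Inferred from the calls in the diff above; not the actual misc.AverageMeter.
class AverageMeter:
    def __init__(self, name: str, fmt: str = "f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        # Clear the running statistics.
        self.sum = 0.0
        self.count = 0

    def update(self, val: float, n: int = 1):
        # Fold in one observation, weighted by n (e.g. the batch size).
        self.sum += val * n
        self.count += n

    @property
    def avg(self) -> float:
        # Weighted mean of everything seen since the last reset.
        return self.sum / self.count if self.count > 0 else 0.0

    @property
    def mean(self) -> str:
        # Formatted "name: value" string, as joined into time_msg above.
        return f"{self.name}: {self.avg:{self.fmt}}"


# Hypothetical usage mirroring update_train_loss: one meter per loss key,
# updated with each batch-mean loss and the batch size.
meter = AverageMeter("loss", "7.5f")
meter.update(0.5, n=32)
meter.update(0.3, n=32)
print(meter.mean)  # loss: 0.40000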