-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathutils.py
More file actions
267 lines (213 loc) · 8.31 KB
/
utils.py
File metadata and controls
267 lines (213 loc) · 8.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
# Copyright 2025-2026 Muhammad Nizwa. All rights reserved.
import torch
from typing import Dict, Optional
def is_improvement(value: float, best: float, mode: str, epsilon: float) -> bool:
    """
    Decide whether ``value`` beats ``best`` by more than ``epsilon``.

    Args:
        value (float): Newly observed metric value.
        best (float): Best metric value recorded so far.
        mode (str): "min" if lower is better; any other value means higher is better.
        epsilon (float): Minimum margin required to count as an improvement.

    Returns:
        bool: True when the new value is a strict improvement over the best.
    """
    delta = best - value if mode == "min" else value - best
    return delta > epsilon
class TrainCheckpoint:
    """
    Saves a checkpoint whenever the monitored metric improves.

    Keeps the best metric value observed so far and writes ``checkpoint_dict``
    to ``filepath`` via ``torch.save`` each time ``step`` observes an
    improvement larger than ``epsilon``.

    Args:
        filepath (str): Destination path for saved checkpoint files.
        mode (str): Optimization mode, either "min" or "max".
        epsilon (float): Minimum improvement threshold.

    Raises:
        AssertionError: If mode is not "min" or "max".
    """

    def __init__(self, filepath: str, mode: str = "min", epsilon: float = 0.0) -> None:
        assert mode in {"min", "max"}
        self.filepath = filepath
        self.mode = mode
        self.epsilon = epsilon
        self.reinit()

    def reinit(self) -> None:
        # Start from the worst possible value so the first metric always wins.
        self.best_value = -float("inf") if self.mode == "max" else float("inf")

    def step(self, value: float, checkpoint_dict: Dict) -> None:
        """
        Evaluate the current metric and save a checkpoint on improvement.

        Args:
            value (float): Current metric value to evaluate.
            checkpoint_dict (Dict): Dictionary containing model state, optimizer state, etc. to save.
        """
        if not is_improvement(value, self.best_value, self.mode, self.epsilon):
            return
        self.best_value = value
        torch.save(checkpoint_dict, self.filepath)
class EarlyStopping:
    """
    Stops training once the monitored metric stalls for ``patience`` epochs.

    Args:
        patience (int): Number of epochs with no improvement before stopping.
        epsilon (float): Minimum improvement threshold.
        mode (str): Optimization mode, either "min" or "max".

    Raises:
        AssertionError: If mode is not "min" or "max".
    """

    def __init__(self, patience: int, epsilon: float = 1e-4, mode: str = "min") -> None:
        assert mode in {"min", "max"}
        self.patience = patience
        self.epsilon = epsilon
        self.mode = mode
        self.reinit()

    def reinit(self) -> None:
        # Worst possible starting point so the first observation always improves.
        self.best_value = -float("inf") if self.mode == "max" else float("inf")
        self.wait = 0

    def step(self, value: float) -> bool:
        """
        Record one epoch's metric value and report whether to stop.

        Args:
            value (float): Current metric value to evaluate.

        Returns:
            bool: True if training should stop, False otherwise.
        """
        if is_improvement(value, self.best_value, self.mode, self.epsilon):
            print(f"* metrics improved from {self.best_value:.6f} to {value:.6f}")
            self.best_value = value
            self.wait = 0
            return False
        print(f"metrics did not improve from {self.best_value:.6f}")
        self.wait += 1
        if self.wait < self.patience:
            return False
        print(f"Early stopping, no improvement in the last {self.patience} epochs")
        return True
class ReduceLROnPlateau:
    """
    Learning rate scheduling that reduces the optimizer's learning rate
    by a factor when a metric fails to improve for a specified number of epochs.
    Includes cooldown period to avoid thrashing and minimum learning rate boundary.

    Args:
        factor (float): Multiplicative factor for learning rate reduction (e.g., 0.1)
        patience (int): Number of epochs to wait before reducing LR
        cooldown (int): Number of epochs to wait after reduction before next reduction
        mode (str): "min" for minimization or "max" for maximization, default "min"
        epsilon (float): Minimum improvement threshold, default 1e-4
        min_lr (float): Minimum learning rate boundary, default 1e-6

    Raises:
        AssertionError: If mode is not "min" or "max"
    """

    def __init__(
        self,
        factor: float,
        patience: int,
        cooldown: int,
        mode: str = "min",
        epsilon: float = 1e-4,
        min_lr: float = 1e-6,
    ) -> None:
        assert mode in {"min", "max"}
        self.mode = mode
        self.factor = factor
        self.patience = patience
        self.epsilon = epsilon
        self.min_lr = min_lr
        self.cooldown = cooldown
        self.reinit()

    def reinit(self) -> None:
        # Reset tracking state; best starts at the worst value for the mode.
        self.best_value = float("inf") if self.mode == "min" else -float("inf")
        self.wait = 0
        self.cooldown_counter = 0

    def in_cooldown(self) -> bool:
        return self.cooldown_counter > 0

    def step(self, value: float, optimizer_dict: Dict) -> None:
        """
        Evaluate the metric and reduce learning rates after ``patience``
        non-improving epochs (outside of cooldown).

        Args:
            value (float): Current metric value to evaluate
            optimizer_dict (Dict): Dictionary containing optimizer(s) to update learning rates
        """
        if self.cooldown_counter > 0:
            # Still cooling down from the previous reduction; skip evaluation.
            self.cooldown_counter -= 1
            return
        if is_improvement(value, self.best_value, self.mode, self.epsilon):
            self.best_value = value
            self.wait = 0
            return
        self.wait += 1
        if self.wait >= self.patience:
            for optimizer in optimizer_dict.values():
                for group in optimizer.param_groups:
                    old_lr = group["lr"]
                    new_lr = max(old_lr * self.factor, self.min_lr)
                    group["lr"] = new_lr
                    # Report inside the loop: previously this print sat after both
                    # loops, raising NameError when there were no param groups and
                    # otherwise reporting only the last group's rates.
                    print(f"Reducing LR in the next epoch from {old_lr:.6f} to {new_lr:.6f}")
            self.wait = 0
            self.cooldown_counter = self.cooldown
class TrainingCallback:
    """
    Unified callback handler that orchestrates callback execution
    during training and manages their interdependencies.

    Args:
        checkpoint (TrainCheckpoint, optional): Checkpoint handler, default None
        early_stop (EarlyStopping, optional): Early stopping handler, default None
        reduce_lr (ReduceLROnPlateau, optional): Learning rate scheduler, default None
    """

    def __init__(
        self,
        checkpoint: Optional[TrainCheckpoint] = None,
        early_stop: Optional[EarlyStopping] = None,
        reduce_lr: Optional[ReduceLROnPlateau] = None,
    ) -> None:
        self.cp = checkpoint
        self.es = early_stop
        self.rlr = reduce_lr

    def init(self) -> None:
        # Reset every configured handler to its initial state.
        for handler in (self.cp, self.es, self.rlr):
            if handler:
                handler.reinit()

    def step(
        self,
        monitor_value: float,
        epoch: int,
        model_dict: Optional[Dict] = None,
        optimizer_dict: Optional[Dict] = None,
    ) -> bool:
        """
        Execute callbacks for the current training step.

        Args:
            monitor_value (float): The metric value to monitor
            epoch (int): Current epoch number
            model_dict (Dict, optional): Model state dict for checkpointing, default None
            optimizer_dict (Dict, optional): Optimizer state dict for checkpointing, default None

        Returns:
            bool: True if training should stop (early stopping triggered), False otherwise
        """
        # Checkpointing: only runs when a (non-empty) model state was supplied.
        if self.cp and model_dict:
            payload = {"model": model_dict, "epoch": epoch}
            if optimizer_dict:
                # Snapshot state_dicts for optimizer-like entries, raw values otherwise.
                opt_states = {}
                for name, opt in optimizer_dict.items():
                    opt_states[name] = opt.state_dict() if hasattr(opt, "state_dict") else opt
                payload["optimizer"] = opt_states
            self.cp.step(monitor_value, payload)
        # Learning rate scheduling.
        if self.rlr and optimizer_dict:
            self.rlr.step(monitor_value, optimizer_dict)
        # Early stopping; its verdict is suppressed while the LR scheduler is
        # in its cooldown window.
        if not self.es:
            return False
        should_stop = self.es.step(monitor_value)
        if self.rlr and self.rlr.in_cooldown():
            return False
        return should_stop
def time_formatter(sec_elapsed: float) -> str:
    """
    Format elapsed time in seconds to a human-readable clock string.

    Minutes and seconds are zero-padded so the output is unambiguous:
    previously 3661.5s rendered as "1:1:1.5"; it now renders as "1:01:01.5",
    matching the documented HH:MM:SS-style layout (hours are left unpadded).

    Args:
        sec_elapsed (float): Elapsed time in seconds (assumed non-negative).

    Returns:
        str: Formatted time string "H:MM:SS.s".
    """
    minutes, seconds = divmod(sec_elapsed, 60)
    hours, minutes = divmod(int(minutes), 60)
    # Seconds keep one decimal place and are padded to width 4 ("SS.s").
    return f"{hours}:{minutes:02d}:{round(seconds, 1):04.1f}"