-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathcontext.py
More file actions
325 lines (253 loc) · 8.78 KB
/
context.py
File metadata and controls
325 lines (253 loc) · 8.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
import torch
from torch.utils.tensorboard.writer import SummaryWriter
from pathlib import Path
import os
import json
from typing import (
List,
NamedTuple,
Optional,
TYPE_CHECKING,
)
from shutil import rmtree
from xpmir.context import InitializationHook, Hook
from xpmir.utils.utils import easylog
from xpmir.learning.devices import DeviceInformation, ComputationContext
from xpmir.learning.metrics import Metric, Metrics
from experimaestro.utils import cleanupdir
from contextlib import contextmanager
if TYPE_CHECKING:
from xpmir.learning.optim import ScheduledOptimizer, Module
from xpmir.learning.trainers import Trainer
logger = easylog()
class Loss(NamedTuple):
    """A single named, weighted loss component.

    Groups together the identifier of the loss, the computed tensor
    value, and the weight to apply when it is aggregated with other
    losses.
    """

    name: str
    value: torch.Tensor
    weight: float
class TrainState:
    """Represents a training state for serialization

    Bundles the model, trainer and optimizer together with the training
    progress counters (epoch and steps), and knows how to serialize
    itself to, and restore itself from, a checkpoint directory.
    """

    # Checkpoint file names (shared by save/load/copy_model so the
    # on-disk layout is defined in exactly one place)
    MODEL_PATH = "model.pth"
    TRAINER_PATH = "trainer.pth"
    OPTIMIZER_PATH = "optimizer.pth"
    INFO_PATH = "info.json"

    epoch: int
    """The epoch"""

    steps: int
    """The number of steps (each epoch is composed of steps)"""

    def __init__(
        self,
        model: "Module",
        trainer: "Trainer",
        optimizer: "ScheduledOptimizer",
        epoch=0,
        steps=0,
    ):
        """Initialize a training state

        :param model: the model being trained
        :param trainer: the trainer (holds training-specific state)
        :param optimizer: the scheduled optimizer
        :param epoch: starting epoch (defaults to 0)
        :param steps: starting step count (defaults to 0)
        """
        # Initialize the state
        self.model = model
        self.trainer = trainer
        self.optimizer = optimizer
        self.epoch = epoch
        self.steps = steps

        # Was it loaded from disk?
        self.cached = False

        # Was it saved? (directory path when saved, None otherwise)
        self.path = None

    def copy(self):
        """Return a fresh (non-cached, unsaved) state sharing the same
        model/trainer/optimizer and the same counters"""
        return TrainState(self.model, self.trainer, self.optimizer, **self.state_dict())

    def state_dict(self):
        """Return the JSON-serializable part of the state"""
        return {
            "epoch": self.epoch,
            "steps": self.steps,
        }

    @property
    def step(self):
        """Returns the step for logging (number of steps)"""
        return self.steps

    def load_state_dict(self, state):
        """Restore the counters from a dict produced by :meth:`state_dict`

        Missing keys default to 0.
        """
        self.epoch = state.get("epoch", 0)
        self.steps = state.get("steps", 0)

    def save(self, path):
        """Save the state into the directory `path` (cleaned up first)"""
        cleanupdir(path)

        with (path / TrainState.INFO_PATH).open("wt") as fp:
            json.dump(self.state_dict(), fp)

        torch.save(self.model.state_dict(), path / TrainState.MODEL_PATH)
        torch.save(self.trainer.state_dict(), path / TrainState.TRAINER_PATH)
        torch.save(self.optimizer.state_dict(), path / TrainState.OPTIMIZER_PATH)
        self.path = path

    def load(self, path, onlyinfo=False):
        """Loads the state from disk

        :param path: the checkpoint directory
        :param onlyinfo: when True, only restore the counters (skip the
            model/trainer/optimizer weights)
        """
        if not onlyinfo:
            self.model.load_state_dict(torch.load(path / TrainState.MODEL_PATH))
            self.trainer.load_state_dict(torch.load(path / TrainState.TRAINER_PATH))
            self.optimizer.load_state_dict(torch.load(path / TrainState.OPTIMIZER_PATH))

        with (path / TrainState.INFO_PATH).open("rt") as fp:
            self.load_state_dict(json.load(fp))

        self.path = path
        self.cached = True

    def copy_model(self, path: Path):
        """Hard-link the model weights and info file into `path`

        Requires that this state has already been saved.
        """
        assert self.path is not None
        for name in [TrainState.MODEL_PATH, TrainState.INFO_PATH]:
            os.link(self.path / name, path / name)
class TrainingHook(Hook):
    """Base class for all training hooks"""
class ValidationHook(Hook):
    """Base class for all the validation hooks"""

    def before(self, state: "TrainerContext"):
        """Hook invoked just before a validation step runs"""

    def after(self, state: "TrainerContext"):
        """Hook invoked once a validation step has completed"""
class StepTrainingHook(TrainingHook):
    """Base class for hooks called at each step (before/after)"""

    def before(self, state: "TrainerContext"):
        """Hook invoked just before a training step runs"""

    def after(self, state: "TrainerContext"):
        """Hook invoked once a training step has completed"""
class InitializationTrainingHook(TrainingHook, InitializationHook):
    """Base class for hooks called at initialization"""

    def before(self, state: "TrainerContext"):
        """No-op by default; subclasses may act before initialization"""

    def after(self, state: "TrainerContext"):
        """No-op by default; subclasses may act after initialization"""
class TrainerContext(ComputationContext):
    """Contains all the information about the training context
    for a specific training run.

    This object is used when training to provide models and losses
    with extra information - as well as the possibility to add
    regularization losses
    """

    metrics: Optional[Metrics]
    """Metrics to be reported"""

    _losses: Optional[List[Loss]]
    """Regularization losses to be added to the main loss"""

    _scope: List[str]
    """Scope for metric names"""

    PREFIX = "epoch-"

    def __init__(
        self,
        device_information: DeviceInformation,
        logpath: Path,
        path: Path,
        max_epoch: int,
        steps_per_epoch: int,
        trainer,
        model: "Module",
        optimizer: "ScheduledOptimizer",
    ):
        """Initialize the training context

        :param device_information: information on the computation device
        :param logpath: directory for tensorboard logs
        :param path: directory where checkpoints are stored
        :param max_epoch: maximum number of epochs
        :param steps_per_epoch: number of steps per epoch
        :param trainer: the trainer
        :param model: the model being trained
        :param optimizer: the scheduled optimizer
        """
        super().__init__()
        self.device_information = device_information
        self.path = path
        self.logpath = logpath
        self.max_epoch = max_epoch
        self.steps_per_epoch = steps_per_epoch
        self._writer = None
        self._scope = []
        self._losses = None
        # Fix: define these up-front so that save_checkpoint() and
        # add_metric() do not raise AttributeError when called before
        # nextepoch() / the first step() context
        self.oldstate: Optional[TrainState] = None
        self.metrics: Optional[Metrics] = None
        self.state = TrainState(model, trainer, optimizer)

    @property
    def writer(self):
        """Returns a tensorboard writer

        by default, purges the entries beside the current epoch
        """
        if self._writer is None:
            self._writer = SummaryWriter(self.logpath, purge_step=self.state.step)
        return self._writer

    @property
    def epoch(self):
        """The current epoch number"""
        return self.state.epoch

    @property
    def steps(self):
        """The current number of steps"""
        return self.state.steps

    def nextepoch(self):
        """Move to the next epoch, keeping the previous state around so
        its checkpoint can be cleaned up after the next save"""
        self.oldstate = self.state
        self.state = self.oldstate.copy()
        self.state.epoch += 1

    def nextbatch(self):
        """Advance the step counter by one"""
        self.state.steps += 1

    def load_bestcheckpoint(self, max_epoch: int):
        """Try to find the best checkpoint to load (the highest lower than
        the epoch target)

        :returns: True if a checkpoint was successfully loaded
        """
        # Find all the potential epochs
        epochs = []
        for f in self.path.glob(f"{TrainerContext.PREFIX}*"):
            epoch = int(f.name[len(TrainerContext.PREFIX) :])
            if epoch <= max_epoch:
                epochs.append(epoch)

        epochs.sort(reverse=True)

        # Try to load the first one (highest epoch); corrupted
        # checkpoints are removed and the next candidate is tried
        for epoch in epochs:
            logger.info("Loading from checkpoint at epoch %d", epoch)
            path = self.path / f"{TrainerContext.PREFIX}{epoch:08d}"

            try:
                self.state.load(path)
                return True
            except NotImplementedError:
                # Deliberate signal: keep the checkpoint and abort
                logger.error("Not removing checkpoint")
                raise
            except Exception:
                rmtree(path)
                logger.exception("Cannot load from epoch %d", epoch)

        return False

    @staticmethod
    def get_checkpoint_path(checkpointspath: Path, epoch: int) -> Path:
        """Return the checkpoint directory path for a given epoch"""
        return checkpointspath / f"{TrainerContext.PREFIX}{epoch:08d}"

    def save_checkpoint(self):
        """Serialize the current state as a checkpoint, removing the
        previous epoch's checkpoint if any"""
        if self.state.path is not None:
            # No need to save twice
            return

        # Serialize
        path = TrainerContext.get_checkpoint_path(self.path, self.epoch)
        self.state.save(path)

        # Cleanup the previous checkpoint if needed
        if self.oldstate and self.oldstate.path:
            try:
                rmtree(self.oldstate.path)
            except OSError as e:
                # We continue the learning process in those cases
                logger.error("OS Error while trying to remove directory %s", e)
            self.oldstate = None

    def copy(self, path: Path):
        """Copy the state into another folder"""
        if self.state.path is None:
            self.save_checkpoint()

        trainpath = self.state.path
        assert trainpath is not None

        if path:
            cleanupdir(path)
            self.state.copy_model(path)

    def add_loss(self, loss: Loss):
        """Add a regularization loss; only valid within a losses() context"""
        assert (
            self._losses is not None
        ), "This should be called in the context where loss is computed"
        self._losses.append(loss)

    @contextmanager
    def losses(self):
        """Context manager providing a fresh list collecting regularization
        losses; the previous list (if any) is restored on exit"""
        previous = self._losses
        try:
            self._losses = []
            yield self._losses
        finally:
            self._losses = previous

    @contextmanager
    def step(self, metrics):
        """Context manager for one optimization step: zeroes gradients,
        collects metrics, then steps the optimizer and scheduler and
        merges the collected metrics into `metrics`"""
        try:
            self.state.optimizer.zero_grad()
            self.metrics = Metrics()
            yield self.metrics
            self.state.optimizer.optimizer_step(self)
            self.state.optimizer.scheduler_step(self)
            metrics.merge(self.metrics)
        finally:
            self.metrics = None

    def add_metric(self, metric: Metric):
        """Report a metric; only valid within a step() context.

        The metric key is prefixed with the current scope (if any)."""
        assert self.metrics is not None, "Not within an optimization step"
        if self._scope:
            metric.key = "/".join(s for s in self._scope if s) + "/" + metric.key
        self.metrics.add(metric)

    @contextmanager
    def scope(self, name: str):
        """Context manager pushing `name` onto the metric-name scope"""
        try:
            self._scope.append(name)
            yield self
        finally:
            self._scope.pop()