Commit 8eed347

Improve train code:

* Move general stuff from pose estimator training to train.py
* Autoformat
* Add smoketest back in
1 parent: 320b39c

File tree: 4 files changed, +382 −233 lines


run.sh

Lines changed: 5 additions & 16 deletions
@@ -1,20 +1,9 @@
 #!/bin/bash

-# python scripts/train_poseestimator.py --lr 1.e-3 --epochs 500 --ds "repro_300_wlp+lapa_megaface_lp+wflw_lp+synface" \
-#     --save-plot train.pdf \
-#     --with-swa \
-#     --with-nll-loss \
-#     --roi-override original \
-#     --no-onnx \
-#     --backbone mobilenetv1 \
-#     --outdir model_files/
-
-
-
-#--rampup_nll_losses \
-
-python scripts/train_poseestimator_lightning.py --ds "repro_300_wlp+lapa_megaface_lp+wflw_lp+synface" \
-    --epochs 10 \
+python scripts/train_poseestimator.py --lr 1.e-3 --epochs 1500 --ds "repro_300_wlp+lapa_megaface_lp+wflw_lp+synface" \
     --with-swa \
     --with-nll-loss \
-    --rampup-nll-losses
+    --backbone hybrid_vit \
+    --rampup-nll-losses
+
+# --outdir model_files/current/run0/
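The new command can also be launched from Python, which is handy for scripted sweeps. A minimal sketch; the flags are copied verbatim from the updated run.sh above, and it assumes execution from the repository root:

import subprocess

# Launch the same training run that run.sh now performs.
subprocess.run(
    [
        "python", "scripts/train_poseestimator.py",
        "--lr", "1.e-3",
        "--epochs", "1500",
        "--ds", "repro_300_wlp+lapa_megaface_lp+wflw_lp+synface",
        "--with-swa",
        "--with-nll-loss",
        "--backbone", "hybrid_vit",
        "--rampup-nll-losses",
    ],
    check=True,  # raise if training exits with a non-zero status
)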

scripts/train_poseestimator.py

Lines changed: 6 additions & 131 deletions
@@ -18,10 +18,8 @@
 import tqdm

 import pytorch_lightning as pl
-from pytorch_lightning.callbacks import Callback, ModelCheckpoint
+from pytorch_lightning.callbacks import ModelCheckpoint

-# from pytorch_lightning.loggers import Logger,
-from pytorch_lightning.utilities import rank_zero_only
 import torch.optim as optim
 import torch
 import torch.nn as nn
@@ -32,8 +30,6 @@
 import trackertraincode.train as train
 import trackertraincode.pipelines

-from trackertraincode.neuralnets.io import complement_lightning_checkpoint
-from scripts.export_model import convert_posemodel_onnx
 from trackertraincode.datasets.batch import Batch
 from trackertraincode.pipelines import Tag

@@ -161,11 +157,6 @@ def create_optimizer(net, args: MyArgs):
     return optimizer, scheduler


-class SaveBestSpec(NamedTuple):
-    weights: List[float]
-    names: List[str]
-
-
 def setup_losses(args: MyArgs, net):
     C = train.Criterion
     cregularize = [
@@ -259,9 +250,7 @@ def wrapped(step):
         ),
     }

-    savebest = SaveBestSpec([1.0, 1.0, 1.0], ["rot", "xy", "sz"])
-
-    return train_criterions, test_criterions, savebest
+    return train_criterions, test_criterions


 def create_net(args: MyArgs):
@@ -281,7 +270,7 @@ def __init__(self, args: MyArgs):
         super().__init__()
         self._args = args
         self._model = create_net(args)
-        train_criterions, test_criterions, savebest = setup_losses(args, self._model)
+        train_criterions, test_criterions = setup_losses(args, self._model)
         self._train_criterions = train_criterions
         self._test_criterions = test_criterions

@@ -315,120 +304,6 @@ def model(self):
         return self._model


-class SwaCallback(Callback):
-    def __init__(self, start_epoch):
-        super().__init__()
-        self._swa_model: optim.swa_utils.AveragedModel | None = None
-        self._start_epoch = start_epoch
-
-    @property
-    def swa_model(self):
-        return self._swa_model.module
-
-    def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
-        assert isinstance(pl_module, LitModel)
-        self._swa_model = optim.swa_utils.AveragedModel(pl_module.model, device="cpu", use_buffers=True)
-
-    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
-        assert isinstance(pl_module, LitModel)
-        if trainer.current_epoch > self._start_epoch:
-            self._swa_model.update_parameters(pl_module.model)
-
-    def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
-        assert self._swa_model is not None
-        swa_filename = join(trainer.default_root_dir, f"swa.ckpt")
-        models.save_model(self._swa_model.module, swa_filename)
-
-
-class MetricsGraphing(Callback):
-    def __init__(self):
-        super().__init__()
-        self._visu: train.TrainHistoryPlotter | None = None
-        self._metrics_accumulator = defaultdict(list)
-
-    def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
-        assert self._visu is None
-        self._visu = train.TrainHistoryPlotter(save_filename=join(trainer.default_root_dir, "train.pdf"))
-
-    def on_train_batch_end(
-        self, trainer: pl.Trainer, pl_module: pl.LightningModule, outputs: Any, batch: Any, batch_idx: int
-    ):
-        mt_losses: dict[str, torch.Tensor] = outputs["mt_losses"]
-        for k, v in mt_losses.items():
-            self._visu.add_train_point(trainer.current_epoch, batch_idx, k, v.numpy())
-        self._visu.add_train_point(trainer.current_epoch, batch_idx, "loss", outputs["loss"].detach().cpu().numpy())
-
-    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
-        if trainer.lr_scheduler_configs:  # scheduler is not None:
-            scheduler = next(
-                iter(trainer.lr_scheduler_configs)
-            ).scheduler  # Pick the first scheduler (and there should only be one)
-            last_lr = next(iter(scheduler.get_last_lr()))  # LR from the first parameter group
-            self._visu.add_test_point(trainer.current_epoch, "lr", last_lr)
-
-        self._visu.summarize_train_values()
-
-    def on_validation_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
-        self._metrics_accumulator = defaultdict(list)
-
-    def on_validation_batch_end(
-        self,
-        trainer: pl.Trainer,
-        pl_module: pl.LightningModule,
-        outputs: list[train.LossVal],
-        batch: Any,
-        batch_idx: int,
-        dataloader_idx: int = 0,
-    ) -> None:
-        for val in outputs:
-            self._metrics_accumulator[val.name].append(val.val)
-
-    def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
-        if self._visu is None:
-            return
-        for k, v in self._metrics_accumulator.items():
-            self._visu.add_test_point(trainer.current_epoch - 1, k, torch.cat(v).mean().cpu().numpy())
-        if trainer.current_epoch > 0:
-            self._visu.update_graph()
-
-
-class SimpleProgressBar(Callback):
-    """Creates progress bars for total training time and progress of per epoch."""
-
-    def __init__(self, batchsize: int):
-        super().__init__()
-        self._bar: tqdm.tqdm | None = None
-        self._epoch_bar: tqdm.tqdm | None = None
-        self._batchsize = batchsize
-
-    def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
-        self._bar = tqdm.tqdm(total=trainer.max_epochs, desc='Training', position=0)
-        self._epoch_bar = tqdm.tqdm(total=trainer.num_training_batches * self._batchsize, desc="Epoch", position=1)
-
-    def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
-        self._bar.close()
-        self._epoch_bar.close()
-        self._bar = None
-        self._epoch_bar = None
-
-    def on_train_epoch_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
-        self._epoch_bar.reset(self._epoch_bar.total)
-
-    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
-        self._bar.update(1)
-
-    def on_train_batch_end(
-        self,
-        trainer: pl.Trainer,
-        pl_module: pl.LightningModule,
-        outputs: Mapping[str, Any],
-        batch: list[Batch] | Batch,
-        batch_idx: int,
-    ) -> None:
-        n = sum(b.meta.batchsize for b in batch) if isinstance(batch, list) else batch.meta.batchsize
-        self._epoch_bar.update(n)
-
-
 def main():
     np.seterr(all="raise")
     cv2.setNumThreads(1)
@@ -499,13 +374,13 @@ def main():
         save_weights_only=False,
     )

-    progress_cb = SimpleProgressBar(args.batchsize)
+    progress_cb = train.SimpleProgressBar(args.batchsize)

-    callbacks = [MetricsGraphing(), checkpoint_cb, progress_cb]
+    callbacks = [train.MetricsGraphing(), checkpoint_cb, progress_cb]

     swa_callback = None
     if args.swa:
-        swa_callback = SwaCallback(start_epoch=args.epochs * 2 // 3)
+        swa_callback = train.SwaCallback(start_epoch=args.epochs * 2 // 3)
     callbacks.append(swa_callback)

     # TODO: inf norm?
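With the callbacks moved out, the call site in main() reduces to referencing them through the train module. A minimal sketch of the resulting pattern, assuming the classes keep the constructor signatures shown in the deleted code (the Trainer settings here are illustrative, not the script's actual configuration):

import pytorch_lightning as pl
import trackertraincode.train as train

epochs = 30  # illustrative value

callbacks = [
    train.MetricsGraphing(),                         # plots losses/metrics into train.pdf
    train.SimpleProgressBar(batchsize=32),           # total and per-epoch progress bars
    train.SwaCallback(start_epoch=epochs * 2 // 3),  # average weights over the final third
]
trainer = pl.Trainer(max_epochs=epochs, callbacks=callbacks)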

test/test_train.py

Lines changed: 138 additions & 6 deletions
@@ -1,22 +1,154 @@
 from torch.utils.data import Dataset, DataLoader
 import time
 import torch
+from torch import nn
 import numpy as np
+import os
 import functools
-from typing import List
-from trackertraincode.datasets.batch import Batch, Metadata
+from typing import List, Any
+import itertools
+from pytorch_lightning.callbacks import ModelCheckpoint
+import pytorch_lightning as pl
+import matplotlib
+import matplotlib.pyplot
+import time

+from trackertraincode.datasets.batch import Batch, Metadata
 import trackertraincode.train as train

+
 def test_plotter():
     plotter = train.TrainHistoryPlotter()
-    names = [ 'foo', 'bar', 'baz', 'lr' ]
+    names = ['foo', 'bar', 'baz', 'lr']
     for e in range(4):
         for t in range(5):
             for name in names[:-2]:
-                plotter.add_train_point(e, t, name, 10. + e + np.random.normal(0., 1.,(1,)))
+                plotter.add_train_point(e, t, name, 10.0 + e + np.random.normal(0.0, 1.0, (1,)))
         for name in names[1:]:
-            plotter.add_test_point(e, name, 9. + e + np.random.normal())
+            plotter.add_test_point(e, name, 9.0 + e + np.random.normal())
         plotter.summarize_train_values()
     plotter.update_graph()
-    plotter.close()
+    plotter.close()
+
+
+class MseLoss(object):
+    def __call__(self, pred, batch):
+        return torch.nn.functional.mse_loss(pred['test_head_out'], batch['y'], reduction='none')
+
+
+class L1Loss(object):
+    def __call__(self, pred, batch):
+        return torch.nn.functional.l1_loss(pred['test_head_out'], batch['y'], reduction='none')
+
+
+class CosineDataset(Dataset):
+    def __init__(self, n):
+        self.n = n
+
+    def __len__(self):
+        return self.n
+
+    def __getitem__(self, i):
+        x = torch.rand((1,))
+        y = torch.cos(x)
+        return Batch(Metadata(0, batchsize=0), {'image': x, 'y': y})
+
+
+class MockupModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.layers = torch.nn.Sequential(torch.nn.Linear(1, 128), torch.nn.ReLU(), torch.nn.Linear(128, 1))
+
+    def forward(self, x: torch.Tensor):
+        return {'test_head_out': self.layers(x)}
+
+    def get_config(self):
+        return {}
+
+
+class LitModel(pl.LightningModule):
+    def __init__(self):
+        super().__init__()
+        self._model = MockupModel()
+        self._train_criterions = self.__setup_train_criterions()
+        self._test_criterion = train.Criterion('test_head_out_c1', MseLoss(), 1.0)
+
+    def __setup_train_criterions(self):
+        c1 = train.Criterion('c1', MseLoss(), 0.42)
+        c2 = train.Criterion('c2', L1Loss(), 0.7)
+        return train.CriterionGroup([c1, c2], 'test_head_out_')
+
+    def training_step(self, batch: Batch, batch_idx):
+        loss_sum, all_lossvals = train.default_compute_loss(
+            self._model, [batch], self.current_epoch, self._train_criterions
+        )
+        loss_val_by_name = {
+            name: val
+            for name, (val, _) in train.concatenated_lossvals_by_name(
+                itertools.chain.from_iterable(all_lossvals)
+            ).items()
+        }
+        self.log("loss", loss_sum, on_epoch=True, prog_bar=True, batch_size=batch.meta.batchsize)
+        return {"loss": loss_sum, "mt_losses": loss_val_by_name}
+
+    def validation_step(self, batch: Batch, batch_idx: int) -> torch.Tensor | dict[str, Any] | None:
+        images = batch["image"]
+        pred = self._model(images)
+        values = self._test_criterion.evaluate(pred, batch, batch_idx)
+        val_loss = torch.cat([(lv.val * lv.weight) for lv in values]).sum()
+        self.log("val_loss", val_loss, on_epoch=True, batch_size=batch.meta.batchsize)
+        return values
+
+    def configure_optimizers(self):
+        optimizer = torch.optim.Adam(self.model.parameters(), lr=1.0e-4)
+        scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1.0e-4, total_steps=50)
+        return {"optimizer": optimizer, "lr_scheduler": scheduler}
+
+    @property
+    def model(self):
+        return self._model
+
+
+def test_train_smoketest(tmp_path):
+    batchsize = 32
+    epochs = 50
+    train_loader = DataLoader(CosineDataset(20), batch_size=batchsize, collate_fn=Batch.collate)
+    test_loader = DataLoader(CosineDataset(8), batch_size=batchsize, collate_fn=Batch.collate)
+    model = LitModel()
+    model_out_dir = os.path.join(tmp_path, 'models')
+
+    checkpoint_cb = ModelCheckpoint(
+        save_top_k=1,
+        save_last=True,
+        monitor="val_loss",
+        enable_version_counter=False,
+        filename="best",
+        dirpath=model_out_dir,
+        save_weights_only=False,
+    )
+
+    progress_cb = train.SimpleProgressBar(batchsize)
+    visu_cb = train.MetricsGraphing()
+    callbacks = [visu_cb, checkpoint_cb, progress_cb, train.SwaCallback(start_epoch=epochs // 2)]
+
+    trainer = pl.Trainer(
+        fast_dev_run=False,
+        gradient_clip_val=1.0,
+        gradient_clip_algorithm="norm",
+        default_root_dir=model_out_dir,
+        # limit_train_batches=((10 * 1024) // batchsize),
+        callbacks=callbacks,
+        enable_checkpointing=True,
+        max_epochs=epochs,
+        log_every_n_steps=1,
+        logger=False,
+        enable_progress_bar=False,
+    )
+
+    trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=test_loader)
+
+    visu_cb.close()
+
+    assert os.path.isfile(tmp_path / 'models' / 'swa.ckpt')
+    assert os.path.isfile(tmp_path / 'models' / 'best.ckpt')
+    assert os.path.isfile(tmp_path / 'models' / 'train.pdf')
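The restored smoketest can be exercised on its own; a sketch, assuming pytest is installed (the node ID follows from the file and function names above):

import pytest

# Run only the smoketest; pytest supplies the tmp_path fixture automatically.
pytest.main(["test/test_train.py::test_train_smoketest", "-q"])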
