Skip to content

Commit a234271

Browse files
lucasalvaa, SimoCimmi, MorganVitiello
committed
Add early stopping and minor changes in dvc.yaml files.
Co-authored-by: SimoCimmi <simonecimmino2004@gmail.com> Co-authored-by: Morgan Vitiello <morgan.vitiello06@gmail.com>
1 parent bf2bb46 commit a234271

File tree

9 files changed

+133
-89
lines changed

9 files changed

+133
-89
lines changed

baseline/dvc.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ stages:
33
foreach: [ effnet_s, effnet_m, convnext ]
44
do:
55
wdir: ..
6-
cmd: python -m src.train --model ${item} --config baseline/params.yaml --output_dir baseline/${item}
6+
cmd: python -m src.train --pipeline baseline --model ${item}
77
deps:
88
- data/split/train
99
- data/split/val
@@ -22,7 +22,7 @@ stages:
2222
foreach: [ effnet_s, effnet_m, convnext ]
2323
do:
2424
wdir: ..
25-
cmd: python -m src.evaluate --model ${item} --config baseline/params.yaml --output_dir baseline/${item}
25+
cmd: python -m src.evaluate --pipeline baseline --model ${item}
2626
deps:
2727
- data/split/test
2828
- src/evaluate.py

pipeline1/dvc.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ stages:
1515
foreach: [ effnet_s, effnet_m, convnext ]
1616
do:
1717
wdir: ..
18-
cmd: python -m src.train --model ${item} --config pipeline1/params.yaml --output_dir pipeline1/${item}
18+
cmd: python -m src.train --pipeline pipeline1 --model ${item}
1919
deps:
2020
- data_augmented/train
2121
- data_augmented/val
@@ -34,7 +34,7 @@ stages:
3434
foreach: [ effnet_s, effnet_m, convnext ]
3535
do:
3636
wdir: ..
37-
cmd: python -m src.evaluate --model ${item} --config pipeline1/params.yaml --output_dir pipeline1/${item}
37+
cmd: python -m src.evaluate --pipeline pipeline1 --model ${item}
3838
deps:
3939
- data_augmented/test
4040
- src/evaluate.py

pipeline2/dvc.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ stages:
33
foreach: [ effnet_s, effnet_m, convnext ]
44
do:
55
wdir: ..
6-
cmd: python -m src.train --model ${item} --config pipeline2/params.yaml --output_dir pipeline2/${item} --tsft true
6+
cmd: python -m src.train --pipeline pipeline2 --model ${item} --tsft true
77
deps:
88
- data/split/train
99
- data/split/val
@@ -21,7 +21,7 @@ stages:
2121
foreach: [ effnet_s, effnet_m, convnext ]
2222
do:
2323
wdir: ..
24-
cmd: python -m src.finetune --model ${item} --config pipeline2/params.yaml --output_dir pipeline2/${item}/finetuned
24+
cmd: python -m src.finetune --pipeline pipeline2 --model ${item}
2525
deps:
2626
- data/split/train
2727
- data/split/val
@@ -42,7 +42,7 @@ stages:
4242
foreach: [ effnet_s, effnet_m, convnext ]
4343
do:
4444
wdir: ..
45-
cmd: python -m src.evaluate --model ${item} --config pipeline2/params.yaml --output_dir pipeline2/${item}/finetuned
45+
cmd: python -m src.evaluate --pipeline pipeline2 --model ${item}
4646
deps:
4747
- data/split/test
4848
- src/evaluate.py

pipeline3/dvc.yaml

Lines changed: 21 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,82 +1,44 @@
11
stages:
2-
balancing:
3-
wdir: ..
4-
cmd: python -m src.preprocessing.balance --config pipeline3/params.yaml
5-
deps:
6-
- src/preprocessing/balance.py
7-
- data/split
8-
params:
9-
- pipeline3/params.yaml:
10-
- data
11-
outs:
12-
- pipeline3/data_balanced
13-
14-
augment:
2+
augmentation:
153
wdir: ..
164
cmd: python -m src.preprocessing.augment --config pipeline3/params.yaml
175
deps:
6+
- data/split
187
- src/preprocessing/balance.py
19-
- pipeline3/data_balanced
208
params:
219
- pipeline3/params.yaml:
2210
- data
23-
outs:
24-
- pipeline3/data_augmented
11+
# outs:
12+
# - pipeline3/data_augmented
2513

26-
train:
14+
stage-1:
2715
foreach: [ effnet_s, effnet_m, convnext ]
2816
do:
2917
wdir: ..
30-
cmd: python -m src.train --model ${item} --config pipeline3/params.yaml --output_dir pipeline3/${item}
18+
cmd: python -m src.train --pipeline pipeline3 --model ${item} --tsft true
3119
deps:
32-
- pipeline3/data_augmented
20+
- data_augmented/train
3321
- data/split/val
3422
- src/train.py
3523
- src/common.py
3624
params:
3725
- pipeline3/params.yaml:
38-
- base
39-
- data
40-
- train
41-
outs:
42-
- pipeline3/${item}/model.pth
43-
- pipeline3/${item}/loss.json
44-
45-
evaluate:
46-
foreach: [ effnet_s, effnet_m, convnext ]
47-
do:
48-
wdir: ..
49-
cmd: python -m src.evaluate --model ${item} --config pipeline3/params.yaml --output_dir pipeline3/${item}
50-
deps:
51-
- data/split/test
52-
- src/evaluate.py
53-
- src/common.py
54-
- pipeline3/${item}/model.pth
55-
params:
56-
- pipeline3/params.yaml:
5726
- base
5827
- data
59-
- evaluate
60-
metrics:
61-
- pipeline3/${item}/metrics.json:
62-
cache: false
63-
plots:
64-
- pipeline3/${item}/cm_data.csv:
65-
template: confusion
66-
x: actual
67-
y: predicted
68-
title: "Pipeline 3 - Balance + Augment CM - ${item}"
69-
cache: false
28+
- train
29+
outs:
30+
- pipeline3/${item}/model.pth
7031

71-
finetuning:
32+
stage-2: # fine-tuning
7233
foreach: [ effnet_s, effnet_m, convnext ]
7334
do:
7435
wdir: ..
75-
cmd: python -m src.finetune --model ${item} --config pipeline3/params.yaml --output_dir pipeline3/${item}/finetuned
36+
cmd: python -m src.finetune --pipeline pipeline3 --model ${item}
7637
deps:
38+
- pipeline3/data_augmented/train
39+
- data/split/val
7740
- src/finetune.py
7841
- src/common.py
79-
- data/split/train
8042
- pipeline3/${item}/model.pth
8143
params:
8244
- pipeline3/params.yaml:
@@ -87,34 +49,32 @@ stages:
8749
- pipeline3/${item}/finetuned/model.pth
8850
- pipeline3/${item}/finetuned/loss.json
8951

90-
91-
ft-evaluate:
52+
evaluate:
9253
foreach: [ effnet_s, effnet_m, convnext ]
9354
do:
9455
wdir: ..
95-
cmd: python -m src.evaluate --model ${item} --config pipeline3/params.yaml --output_dir pipeline3/${item}/finetuned
56+
cmd: python -m src.evaluate --pipeline pipeline3 --model ${item}
9657
deps:
9758
- data/split/test
9859
- src/evaluate.py
9960
- src/common.py
10061
- pipeline3/${item}/finetuned/model.pth
10162
params:
10263
- pipeline3/params.yaml:
103-
- base
104-
- data
105-
- evaluate
64+
- base
65+
- data
66+
- evaluate
10667
metrics:
10768
- pipeline3/${item}/finetuned/metrics.json:
10869
cache: false
10970
plots:
110-
- pipeline3/${item}/finetuned/cm_data.json:
71+
- pipeline3/${item}/finetuned/cm_data.csv:
11172
template: confusion
11273
x: actual
11374
y: predicted
114-
title: "Pipeline 3 - Balance + Augment + Finetune CM - ${item}"
75+
title: "Pipeline 3 - Two-Stage Fine-Tuning CM - ${item}"
11576
cache: false
11677

117-
11878
plots:
11979
- Training_Loss_Comparison:
12080
template: linear

src/early_stopping.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
import torch
import torch.nn as nn


class EarlyStopping:
    """Implement the Automatic Early Stopping technique (Lutz Prechelt, 1998).

    In particular, it uses the GL_alpha criterion: the training stops
    when the Generalization Loss (GL) is greater than the alpha value.
    """

    def __init__(self, alpha: float = 5.0, path: str = "checkpoint.pth") -> None:
        """Initialize the monitor.

        Args:
            alpha: Generalization-Loss threshold as a percentage (e.g. 5.0).
            path: Where to save the best model seen so far (E_opt).
        """
        self.alpha: float = alpha
        self.path: str = path
        # Lowest validation loss observed so far (E_opt in the paper).
        self.min_v_loss: float = float("inf")
        self.best_epoch: int = 0
        # Set to True once the GL criterion fires; callers poll this flag.
        self.stop: bool = False

    def __call__(self, v_loss: float, epoch: int, model: nn.Module) -> None:
        """Check the stopping condition for the current epoch.

        Args:
            v_loss: Validation loss of the current epoch.
            epoch: Index of the current epoch.
            model: Model whose weights are checkpointed on improvement.
        """
        if v_loss < self.min_v_loss:
            self.min_v_loss = v_loss
            self.best_epoch = epoch
            # Save the "optimal" model (E_opt) cited in the paper.
            torch.save(model.state_dict(), self.path)

        # GL(t) = 100 * (E_va(t) / E_opt(t) - 1).
        # Guard the division: a best loss of exactly 0.0 would otherwise
        # raise ZeroDivisionError on the next epoch.
        if self.min_v_loss > 0:
            gl_t = 100 * (v_loss / self.min_v_loss - 1)
        else:
            gl_t = 0.0

        if gl_t > self.alpha:
            print(f"\n[Early Stopping] GL: {gl_t:.2f}% > Alpha: {self.alpha}%")
            self.stop = True

src/evaluate.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,18 +53,18 @@ def evaluate(
5353

5454
def main() -> None:
5555
"""Run test evaluation and save artifacts."""
56+
choices = ["baseline", "pipeline1", "pipeline2", "pipeline3"]
5657
parser = argparse.ArgumentParser()
58+
parser.add_argument("--pipeline", choices=choices, type=str, required=True)
5759
parser.add_argument("--model", type=str, required=True)
58-
parser.add_argument("--config", type=str, required=True)
59-
parser.add_argument("--output_dir", type=str, required=True)
6060
parser.add_argument("--model_path", type=str, default=None)
61-
6261
args = parser.parse_args()
6362

64-
with open(args.config) as conf_file:
65-
config = yaml.safe_load(conf_file)
63+
params_path = Path(args.pipeline) / "params.yaml"
64+
with open(params_path) as f:
65+
config = yaml.safe_load(f)
6666

67-
out_dir = Path(args.output_dir)
67+
out_dir = Path(args.pipeline) / Path(args.model)
6868
out_dir.mkdir(parents=True, exist_ok=True)
6969

7070
test_loader = get_dataloader(

src/finetune.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,28 +3,30 @@
33
import argparse
44
import json
55
from pathlib import Path
6-
from typing import Any
76

87
import torch
98
import yaml
109
from torch import amp, nn, optim
1110

1211
from src.common import DEVICE, get_dataloader, get_model, train_epoch, validate
12+
from src.early_stopping import EarlyStopping
1313

1414

1515
def main() -> None:
1616
"""Execute the fine-tuning pipeline."""
17+
choices = ["baseline", "pipeline1", "pipeline2", "pipeline3"]
1718
parser = argparse.ArgumentParser()
19+
parser.add_argument("--pipeline", choices=choices, type=str, required=True)
1820
parser.add_argument("--model", type=str, required=True)
19-
parser.add_argument("--config", type=str, required=True)
20-
parser.add_argument("--output_dir", type=str, required=True)
21-
21+
# parser.add_argument("--config", type=str, required=True)
22+
# parser.add_argument("--output_dir", type=str, required=True)
2223
args = parser.parse_args()
2324

24-
with open(args.config) as conf_file:
25-
config: dict[str, Any] = yaml.safe_load(conf_file)
25+
params_path = Path(args.pipeline) / "params.yaml"
26+
with open(params_path) as f:
27+
config = yaml.safe_load(f)
2628

27-
out_dir = Path(args.output_dir)
29+
out_dir = Path(args.pipeline / args.model) / "finetuned"
2830
out_dir.mkdir(parents=True, exist_ok=True)
2931

3032
t_loader = get_dataloader(
@@ -43,6 +45,11 @@ def main() -> None:
4345
model.load_state_dict(torch.load(weights_path, map_location=DEVICE))
4446
model.to(DEVICE)
4547

48+
best_model_path = out_dir / "model.pth"
49+
early_stopper = EarlyStopping(
50+
alpha=config["train"].get("alpha", 5.0), path=str(best_model_path)
51+
)
52+
4653
# Unfreeze layers
4754
for param in model.parameters():
4855
param.requires_grad = True
@@ -65,6 +72,15 @@ def main() -> None:
6572
print(
6673
f"Epoch {epoch + 1}/{epochs} | T-Loss: {t_loss:.4f} | V-Loss: {v_loss:.4f}"
6774
)
75+
76+
early_stopper(v_loss, epoch + 1, model)
77+
if early_stopper.stop:
78+
print(
79+
f"Stopping at epoch {epoch + 1}. "
80+
f"Best model was at epoch {early_stopper.best_epoch}"
81+
)
82+
break
83+
6884
print(f"Model {args.model} fine-tuned successfully!")
6985

7086
# Saving the model

src/preprocessing/augment.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -107,14 +107,19 @@ def main() -> None:
107107
"""Run data augmentation."""
108108
parser = argparse.ArgumentParser()
109109
parser.add_argument("--config", type=str, required=True)
110+
parser.add_argument("--force", type=bool, default=False)
110111
args = parser.parse_args()
111112

112113
with open(args.config) as conf_file:
113114
config = yaml.safe_load(conf_file)
114115

115-
process_dataset(
116-
Path(config["data"]["inputset_path"]), Path(config["data"]["augmentedset_path"])
117-
)
116+
input_dir = Path(config["data"]["inputset_path"])
117+
output_dir = Path(config["data"]["augmentedset_path"])
118+
119+
if output_dir.exists() and not args.force:
120+
return
121+
122+
process_dataset(input_dir, output_dir)
118123

119124

120125
if __name__ == "__main__":

0 commit comments

Comments (0)