Skip to content

Commit c342566

Browse files
authored
Implement Two-Stage Fine-Tuning
Now pipeline 2 precisely follows ValizadehAslani et al.'s Two-Stage Fine-Tuning.
2 parents 73e11db + 5bc939b commit c342566

File tree

6 files changed

+174
-254
lines changed

6 files changed

+174
-254
lines changed

pipeline2/dvc.yaml

Lines changed: 8 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,11 @@
11
stages:
2-
balancing:
3-
wdir: ..
4-
cmd: python -m src.preprocessing.balance --config pipeline2/params.yaml
5-
deps:
6-
- src/preprocessing/balance.py
7-
- data/split
8-
params:
9-
- pipeline2/params.yaml:
10-
- data
11-
outs:
12-
- data_balanced
13-
14-
train:
2+
stage-1:
153
foreach: [ effnet_s, effnet_m, convnext ]
164
do:
175
wdir: ..
18-
cmd: python -m src.train --model ${item} --config pipeline2/params.yaml --output_dir pipeline2/${item}
6+
cmd: python -m src.train --model ${item} --config pipeline2/params.yaml --output_dir pipeline2/${item} --tsft true
197
deps:
20-
- data_balanced
8+
- data/split/train
219
- data/split/val
2210
- src/train.py
2311
- src/common.py
@@ -30,41 +18,16 @@ stages:
3018
- pipeline2/${item}/model.pth
3119
- pipeline2/${item}/loss.json
3220

33-
evaluate:
34-
foreach: [ effnet_s, effnet_m, convnext ]
35-
do:
36-
wdir: ..
37-
cmd: python -m src.evaluate --model ${item} --config pipeline2/params.yaml --output_dir pipeline2/${item}
38-
deps:
39-
- data/split/test
40-
- src/evaluate.py
41-
- src/common.py
42-
- pipeline2/${item}/model.pth
43-
params:
44-
- pipeline2/params.yaml:
45-
- base
46-
- data
47-
- evaluate
48-
metrics:
49-
- pipeline2/${item}/metrics.json:
50-
cache: false
51-
plots:
52-
- pipeline2/${item}/cm_data.csv:
53-
template: confusion
54-
x: actual
55-
y: predicted
56-
title: "Pipeline 2 - Balance CM - ${item}"
57-
cache: false
58-
59-
finetuning:
21+
stage-2: # fine-tuning
6022
foreach: [ effnet_s, effnet_m, convnext ]
6123
do:
6224
wdir: ..
6325
cmd: python -m src.finetune --model ${item} --config pipeline2/params.yaml --output_dir pipeline2/${item}/finetuned
6426
deps:
27+
- data/split/train
28+
- data/split/val
6529
- src/finetune.py
6630
- src/common.py
67-
- data/split/train
6831
- pipeline2/${item}/model.pth
6932
params:
7033
- pipeline2/params.yaml:
@@ -76,7 +39,7 @@ stages:
7639
- pipeline2/${item}/finetuned/loss.json
7740

7841

79-
ft-evaluate:
42+
evaluate:
8043
foreach: [ effnet_s, effnet_m, convnext ]
8144
do:
8245
wdir: ..
@@ -99,21 +62,11 @@ stages:
9962
template: confusion
10063
x: actual
10164
y: predicted
102-
title: "Pipeline 2 - Balance + Finetune CM - ${item}"
65+
title: "Pipeline 2 - Two-Stage Fine-Tuning CM - ${item}"
10366
cache: false
10467

10568

10669
plots:
107-
- Training_Loss_Comparison:
108-
template: linear
109-
x: epoch
110-
y:
111-
# Qui confrontiamo le performance di addestramento tra i modelli
112-
effnet_s/loss.json: train_loss
113-
effnet_m/loss.json: train_loss
114-
convnext/loss.json: train_loss
115-
title: "Pipeline 2 - Comparison: Training Loss per Model"
116-
11770
- effnet_s_curves:
11871
template: linear
11972
x: epoch

pipeline2/params.yaml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@ base:
22
image_res: 224
33

44
data:
5-
tobalance_path: data/split
6-
balanced_path: data_balanced/train
7-
trainset_path: data_balanced/train
5+
trainset_path: data/split/train
86
valset_path: data/split/val
97
testset_path: data/split/test
108

src/common.py

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
"""Common utilities for the P1 pipeline, including data loading and model setup."""
1+
"""Common utilities."""
22

33
from pathlib import Path
44

55
import torch
6+
import torch.optim as optim
67
from torch import nn
8+
from torch.amp import GradScaler, autocast
79
from torch.utils.data import DataLoader
810
from torchvision import datasets, models, transforms
911

@@ -37,7 +39,7 @@ def get_dataloader(
3739
dataset,
3840
batch_size=batch_size,
3941
shuffle=("train" in str(data_path)),
40-
num_workers=4, # Consiglio: accelera il caricamento dati
42+
num_workers=4, # Accelera il caricamento dati
4143
pin_memory=True, # Accelera il trasferimento dati alla GPU
4244
)
4345

@@ -65,3 +67,64 @@ def get_model(model_name: str, num_classes: int) -> nn.Module:
6567
model.classifier[2] = nn.Linear(model.classifier[2].in_features, num_classes)
6668

6769
return model.to(DEVICE)
70+
71+
def validate(model: nn.Module, loader: DataLoader, criterion: nn.Module) -> float:
    """Compute the mean per-sample loss of *model* over a validation loader.

    Args:
        model: Network to evaluate; switched to eval mode for the pass.
        loader: Validation DataLoader; ``len(loader.dataset)`` is used to
            normalise the accumulated loss.
        criterion: Loss function (e.g., CrossEntropyLoss or LDAMLoss).

    Returns:
        Sum of per-sample losses divided by the dataset size.

    """
    model.eval()
    total = 0.0
    # Gradients are not needed for evaluation; no_grad saves memory/compute.
    with torch.no_grad():
        for batch, labels in loader:
            batch = batch.to(DEVICE)
            labels = labels.to(DEVICE)
            batch_loss = criterion(model(batch), labels)
            # Weight by batch size so the final mean is per sample,
            # even when the last batch is smaller.
            total += batch_loss.item() * batch.size(0)
    return total / len(loader.dataset)
92+
93+
def train_epoch(
    model: nn.Module,
    loader: DataLoader,
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    scaler: GradScaler,
) -> float:
    """Train *model* for one epoch under automatic mixed precision (AMP).

    Args:
        model: Network to optimise; switched to train mode for the pass.
        loader: Training DataLoader.
        criterion: Loss function.
        optimizer: Optimizer updating the model's parameters.
        scaler: GradScaler handling AMP loss scaling.

    Returns:
        Mean per-sample training loss over the whole dataset.

    """
    model.train()
    total = 0.0

    for batch, labels in loader:
        batch = batch.to(DEVICE)
        labels = labels.to(DEVICE)
        optimizer.zero_grad()

        # Forward pass runs in reduced precision where it is safe to do so.
        with autocast(device_type=DEVICE.type):
            predictions = model(batch)
            batch_loss = criterion(predictions, labels)

        # Scale the loss before backward so small gradients do not
        # underflow in float16; step/update unscale and adjust the scale.
        scaler.scale(batch_loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Weight by batch size so the mean is per sample even when the
        # final batch is smaller than batch_size.
        total += batch_loss.item() * batch.size(0)

    return total / len(loader.dataset)

src/finetune.py

Lines changed: 21 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -8,104 +8,8 @@
88
import torch
99
import yaml
1010
from torch import amp, nn, optim
11-
from torch.utils.data import DataLoader
1211

13-
# Assumendo che la struttura dei package sia corretta rispetto alla root
14-
from src.common import DEVICE, get_dataloader, get_model
15-
16-
17-
def find_best_model(models_list: list[str], base_path: Path = Path(".")) -> str:
18-
"""Identify the model with the highest top1 metric.
19-
20-
Args:
21-
models_list: List of model names to check.
22-
base_path: Directory containing model folders.
23-
24-
Returns:
25-
The name of the best model.
26-
27-
"""
28-
best_top1 = -1.0
29-
best_model_name = ""
30-
31-
for model_name in models_list:
32-
metrics_path = base_path / model_name / "metrics.json"
33-
if metrics_path.exists():
34-
with metrics_path.open("r") as f:
35-
data = json.load(f)
36-
if data["top1"] > best_top1:
37-
best_top1 = data["top1"]
38-
best_model_name = model_name
39-
40-
if not best_model_name:
41-
raise FileNotFoundError(
42-
"No valid metrics.json found to determine the best model."
43-
)
44-
45-
return best_model_name
46-
47-
48-
def validate(model: nn.Module, loader: DataLoader, criterion: nn.Module) -> float:
49-
"""Calculate average loss on the validation set.
50-
51-
Args:
52-
model: The neural network model.
53-
loader: DataLoader for validation.
54-
criterion: Loss function.
55-
56-
Returns:
57-
Average validation loss.
58-
59-
"""
60-
model.eval()
61-
running_loss = 0.0
62-
with torch.no_grad():
63-
for images, targets in loader:
64-
images, targets = images.to(DEVICE), targets.to(DEVICE)
65-
outputs = model(images)
66-
loss = criterion(outputs, targets)
67-
running_loss += loss.item() * images.size(0)
68-
return running_loss / len(loader.dataset)
69-
70-
71-
def train_one_epoch(
72-
model: nn.Module,
73-
loader: DataLoader,
74-
criterion: nn.Module,
75-
optimizer: optim.Optimizer,
76-
scaler: amp.GradScaler,
77-
) -> float:
78-
"""Run one fine-tuning epoch using AMP.
79-
80-
Args:
81-
model: The neural network model.
82-
loader: DataLoader for the training set.
83-
criterion: Loss function.
84-
optimizer: Optimizer.
85-
scaler: GradScaler for AMP.
86-
87-
Returns:
88-
Average training loss.
89-
90-
"""
91-
model.train()
92-
running_loss = 0.0
93-
94-
for images, targets in loader:
95-
images, targets = images.to(DEVICE), targets.to(DEVICE)
96-
optimizer.zero_grad()
97-
98-
with amp.autocast(device_type="cuda" if torch.cuda.is_available() else "cpu"):
99-
outputs = model(images)
100-
loss = criterion(outputs, targets)
101-
102-
scaler.scale(loss).backward()
103-
scaler.step(optimizer)
104-
scaler.update()
105-
106-
running_loss += loss.item() * images.size(0)
107-
108-
return running_loss / len(loader.dataset)
12+
from src.common import DEVICE, get_dataloader, get_model, train_epoch, validate
10913

11014

11115
def main() -> None:
@@ -120,54 +24,55 @@ def main() -> None:
12024
with open(args.config) as conf_file:
12125
config: dict[str, Any] = yaml.safe_load(conf_file)
12226

123-
# 2. Setup directory di output (sovrascriviamo o creiamo una cartella fine_tuned)
12427
out_dir = Path(args.output_dir)
12528
out_dir.mkdir(parents=True, exist_ok=True)
12629

127-
# 3. Caricamento dati (Phase 2 per fine-tuning)
128-
# Nota: Assumiamo che phase2 sia la directory dei dati di training bilanciati
129-
train_loader = get_dataloader(
30+
31+
t_loader = get_dataloader(
13032
data_path=Path(config["finetuning"]["data_path"]),
13133
batch_size=config["finetuning"]["batch_size"],
13234
)
13335

134-
# Usiamo il set di validazione originale per il monitoraggio
135-
val_loader = get_dataloader(
36+
v_loader = get_dataloader(
13637
data_path=Path(config["data"]["valset_path"]),
13738
batch_size=config["finetuning"]["batch_size"],
13839
)
13940

140-
# 4. Inizializzazione modello e caricamento pesi precedenti
141-
model = get_model(args.model, len(train_loader.dataset.classes))
41+
# Model initialization loading first stage's weights
42+
model = get_model(args.model, len(t_loader.dataset.classes))
14243
weights_path = out_dir.parent / "model.pth"
143-
model.load_state_dict(torch.load(weights_path, map_location=DEVICE))
144-
model.to(DEVICE)
44+
model.load_state_dict(torch.load(weights_path, map_location=DEVICE))  # load_state_dict returns key-mismatch info, not the model; get_model already placed it on DEVICE
14545

146-
# 5. Configurazione training
147-
# Per il fine-tuning si usa solitamente un Learning Rate più basso (es. 1e-5 o 1e-4)
46+
# Unfreeze layers
47+
for param in model.parameters():
48+
param.requires_grad = True
14849

50+
# Fine-tuning setup
14951
criterion = nn.CrossEntropyLoss()
15052
optimizer = optim.Adam(model.parameters(), lr=config["finetuning"]["lr"])
15153
scaler = amp.GradScaler()
15254

55+
# Model fine-tuning
15356
history = []
154-
155-
# 6. Loop di fine-tuning
15657
epochs = config["finetuning"]["epochs"]
58+
print(f"Fine-tuning {args.model}...")
15759
for epoch in range(epochs):
158-
t_loss = train_one_epoch(model, train_loader, criterion, optimizer, scaler)
159-
v_loss = validate(model, val_loader, criterion)
60+
t_loss = train_epoch(model, t_loader, criterion, optimizer, scaler)
61+
v_loss = validate(model, v_loader, criterion)
16062

16163
history.append({"epoch": epoch + 1, "train_loss": t_loss, "val_loss": v_loss})
16264

16365
print(
16466
f"Epoch {epoch + 1}/{epochs} | "
165-
f"FT Train Loss: {t_loss:.4f} | "
166-
f"FT Val Loss: {v_loss:.4f}"
67+
f"T-Loss: {t_loss:.4f} | "
68+
f"V-Loss: {v_loss:.4f}"
16769
)
70+
print(f"Model {args.model} fine-tuned successfully!")
16871

169-
# 7. Salvataggio artefatti
72+
# Saving the model
17073
torch.save(model.state_dict(), out_dir / "model.pth")
74+
75+
# Saving training and validation loss in loss.json file
17176
with open(out_dir / "loss.json", "w") as f:
17277
json.dump(history, f, indent=4)
17378

0 commit comments

Comments
 (0)