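"""Training entry point for the late-fusion model.

Loads a YAML config, builds train/val loaders from a manifest CSV, trains a
LateFusionModel with optional gradient accumulation, encoder freezing, and
class weighting, and checkpoints the best model by validation score.

Usage (the config path below is only an example; point it at your own YAML):
    python train.py --config path/to/config.yaml
"""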
import argparse
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import yaml
from torch.optim import AdamW
from tqdm import tqdm

from models.fusion_model import LateFusionModel
from utils.data import make_loader
from utils.metrics import classification_metrics, regression_metrics
from utils.seed import set_seed


def compute_class_weights(labels):
    """Return inverse-frequency weights for binary labels, or None if a class is absent."""
    labels = np.array(labels)
    n0 = (labels == 0).sum()
    n1 = (labels == 1).sum()
    if n0 == 0 or n1 == 0:
        return None
    w0 = len(labels) / (2.0 * n0)
    w1 = len(labels) / (2.0 * n1)
    return torch.tensor([w0, w1], dtype=torch.float32)
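# Example: labels [0, 0, 0, 1] give n0=3, n1=1, so weights = [4/6, 4/2] = [0.67, 2.0]
# and the minority class contributes proportionally more to the cross-entropy loss.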


def train_one_epoch(model, loader, optimizer, cfg, device, class_weights=None):
    model.train()
    task = cfg['task']
    if task == 'binary':
        weight = class_weights.to(device) if class_weights is not None else None
        criterion = nn.CrossEntropyLoss(weight=weight)
    else:
        criterion = nn.HuberLoss()
    accum_steps = cfg['training']['grad_accum_steps']
    step, total_loss = 0, 0.0
    optimizer.zero_grad(set_to_none=True)
    for batch in tqdm(loader, desc="train", leave=False):
        # Move tensor fields to the device; leave non-tensor fields untouched.
        for k in batch:
            if isinstance(batch[k], torch.Tensor):
                batch[k] = batch[k].to(device)
        logits, _ = model(batch, train_mode=True)
        labels = batch['labels']
        loss = criterion(logits, labels) if task == 'binary' else criterion(logits.squeeze(-1), labels)
        # Scale by the accumulation window so accumulated gradients average rather than sum.
        (loss / accum_steps).backward()
        if (step + 1) % accum_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)
        total_loss += loss.item()  # log the unscaled loss
        step += 1
    # Flush gradients left over from a final partial accumulation window.
    if step % accum_steps != 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
    return total_loss / max(1, step)
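# Note: with gradient accumulation, the effective batch size is the loader's
# batch_size times grad_accum_steps, which is usually what the learning rate
# should be tuned against.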


@torch.no_grad()
def evaluate(model, loader, cfg, device):
    model.eval()
    task = cfg['task']
    all_labels, all_scores = [], []
    criterion = nn.CrossEntropyLoss() if task == 'binary' else nn.HuberLoss()
    total_loss = 0.0
    for batch in tqdm(loader, desc="eval", leave=False):
        for k in batch:
            if isinstance(batch[k], torch.Tensor):
                batch[k] = batch[k].to(device)
        logits, _ = model(batch, train_mode=False)
        if task == 'binary':
            loss = criterion(logits, batch['labels'])
            # Score with the positive-class probability.
            scores = torch.softmax(logits, dim=-1)[:, 1]
        else:
            scores = logits.squeeze(-1)
            loss = criterion(scores, batch['labels'])
        all_scores.extend(scores.cpu().numpy().tolist())
        all_labels.extend(batch['labels'].cpu().numpy().tolist())
        total_loss += loss.item()
    avg_loss = total_loss / max(1, len(loader))
    if task == 'binary':
        m = classification_metrics(np.array(all_labels), np.array(all_scores), cfg['evaluation']['threshold'])
    else:
        m = regression_metrics(np.array(all_labels), np.array(all_scores))
    return avg_loss, m
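# Assumed contract of utils.metrics (not shown in this file):
# classification_metrics(y_true, y_score, threshold) returns a dict that includes
# 'f1_macro', and regression_metrics(y_true, y_pred) returns a dict that includes
# 'mae'; main() below selects the best checkpoint from those two keys.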


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True)
    args = parser.parse_args()
    with open(args.config, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    set_seed(cfg['training']['seed'])
    os.makedirs(cfg['training']['output_dir'], exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = LateFusionModel(cfg).to(device)
    # Optional: freeze encoders for stability on small folds.
    if cfg['training'].get('freeze_encoders', False):
        if getattr(model, 'audio', None) is not None:
            for p in model.audio.encoder.parameters():
                p.requires_grad = False
        if getattr(model, 'text', None) is not None:
            for p in model.text.encoder.parameters():
                p.requires_grad = False
    # Allow overriding the modality dropout probability from the config.
    if 'models' in cfg and 'modality_dropout_p' in cfg['models']:
        model.modality_dropout_p = cfg['models']['modality_dropout_p']
    train_loader = make_loader(cfg['data']['manifest_csv'], "train", cfg, shuffle=True)
    val_loader = make_loader(cfg['data']['manifest_csv'], "val", cfg, shuffle=False)
    class_weights = None
    if cfg['task'] == "binary" and cfg['training']['class_weight'] == "auto":
        df = pd.read_csv(cfg['data']['manifest_csv'])
        y = df[df[cfg['data']['split_column']] == "train"][cfg['data']['label_column']].tolist()
        class_weights = compute_class_weights(y)
    # Optimizer param groups: a separate (typically smaller) lr for the fusion weight alpha.
    alpha_lr = float(cfg['training'].get('alpha_lr', cfg['training']['lr']))
    other_params = [p for n, p in model.named_parameters() if p.requires_grad and n != 'alpha']
    param_groups = [
        {"params": other_params, "lr": cfg['training']['lr']},
        {"params": [model.alpha], "lr": alpha_lr},
    ]
    optim = AdamW(param_groups, lr=cfg['training']['lr'], weight_decay=cfg['training']['weight_decay'])
    best_score = -1e9
    best_path = os.path.join(cfg['training']['output_dir'], "best.ckpt")
    for epoch in range(cfg['training']['num_epochs']):
        tr_loss = train_one_epoch(model, train_loader, optim, cfg, device, class_weights)
        val_loss, val_metrics = evaluate(model, val_loader, cfg, device)
        # Model selection: macro-F1 for binary; negated MAE (higher is better) for regression.
        main_metric = val_metrics['f1_macro'] if cfg['task'] == "binary" else -val_metrics['mae']
        print(f"[Epoch {epoch+1}] train_loss={tr_loss:.4f} | val_loss={val_loss:.4f} | metrics={val_metrics}")
        if main_metric > best_score:
            best_score = main_metric
            torch.save({'model': model.state_dict(), 'cfg': cfg}, best_path)
            print(f"** Saved best to {best_path} (score={best_score:.4f})")
    print("Training finished. Best ckpt:", best_path)


if __name__ == "__main__":
    main()
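# Reloading the best checkpoint later (a minimal sketch, assuming the same model
# code and config schema are importable):
#     ckpt = torch.load(best_path, map_location="cpu")
#     model = LateFusionModel(ckpt['cfg'])
#     model.load_state_dict(ckpt['model'])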