-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain.py
More file actions
828 lines (708 loc) · 27.6 KB
/
train.py
File metadata and controls
828 lines (708 loc) · 27.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
"""
train.py
End-to-end training script:
1) Train teacher ResNet50 on public mammography dataset.
2) Train dual-branch student with:
- teacher-student distillation,
- gradient alignment,
- policy-based teacher updates,
- replay regularization,
- EWC/CMD-style regularizers.
All key hyperparameters and dataset paths are defined in the
`if __name__ == "__main__"` block at the bottom of this file.
"""
import math
import time
from typing import List, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
from sklearn.metrics import confusion_matrix, accuracy_score
import torchvision.transforms as transforms
from datasets import (
set_seed,
create_teacher_dataloaders,
PriorCurrentDataset,
create_weighted_sampler,
)
from models import (
create_teacher_model,
partial_freeze_resnet50,
DualBranchStudent,
TeacherFeatureExtractor,
)
from losses import (
student_distillation_loss_dual_feature,
teacher_distillation_loss,
flatten_grad,
load_flattened_grad,
_layer_variances,
covariance_cmd,
)
# ---------------------------------------------------------------------------
# Basic evaluation helpers for teacher
# ---------------------------------------------------------------------------
def evaluate_teacher(
    model: nn.Module, loader: DataLoader, device: torch.device
) -> Tuple[float, float]:
    """
    Score a single-logit classifier with BCE-with-logits loss and accuracy.

    ``loader`` must yield ``(images, labels)`` pairs. The model output is
    squeezed along dim 1, and a sample counts as positive when
    ``sigmoid(logit) > 0.5``.

    Returns:
        ``(mean_loss, accuracy)`` weighted by batch size. Both values are
        0.0 when the loader yields no samples.
    """
    loss_fn = nn.BCEWithLogitsLoss()
    model.eval()
    loss_total = 0.0
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch_images, batch_labels in loader:
            batch_images = batch_images.to(device)
            batch_labels = batch_labels.to(device)
            logits = model(batch_images).squeeze(1)
            batch_size = batch_labels.size(0)
            # Re-weight the per-batch mean loss so the result is a per-sample mean.
            loss_total += loss_fn(logits, batch_labels).item() * batch_size
            predicted = (torch.sigmoid(logits) > 0.5).long()
            n_correct += (predicted == batch_labels.long()).sum().item()
            n_seen += batch_size
    if n_seen == 0:
        return 0.0, 0.0
    return loss_total / n_seen, n_correct / n_seen
def evaluate_teacher_confusion(
    model: nn.Module, loader: DataLoader, device: torch.device
) -> Tuple[float, float, np.ndarray]:
    """
    Evaluate the teacher and also return a binary confusion matrix.

    ``loader`` must yield ``(images, labels)`` pairs. The model output is
    squeezed along dim 1, and a sample is predicted positive when
    ``sigmoid(logit) > 0.5``.

    Returns:
        ``(mean_bce_loss, accuracy, confusion_matrix)``. The confusion matrix
        is always 2x2 with rows = true label ``[0, 1]`` and columns =
        predicted label ``[0, 1]``, even if only one class occurs. For an
        empty loader the result is ``(0.0, 0.0, zeros((2, 2)))``.
    """
    criterion = nn.BCEWithLogitsLoss()
    model.eval()
    running_loss = 0.0
    total = 0
    all_preds: List[int] = []
    all_labels: List[int] = []
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            # BCEWithLogitsLoss needs floating-point targets.
            labels = labels.to(device).float()
            outputs = model(images).squeeze(1)
            loss = criterion(outputs, labels)
            running_loss += loss.item() * labels.size(0)
            total += labels.size(0)
            preds = (torch.sigmoid(outputs) > 0.5).long().cpu().numpy()
            all_preds.extend(preds.tolist())
            all_labels.extend(labels.long().cpu().numpy().tolist())
    if total == 0:
        # Avoid calling sklearn metrics on empty inputs.
        return 0.0, 0.0, np.zeros((2, 2), dtype=np.int64)
    avg_loss = running_loss / total
    acc = float(accuracy_score(all_labels, all_preds))
    # Fixing labels keeps the matrix 2x2 when a class is absent from both
    # the targets and the predictions.
    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
    return avg_loss, acc, cm
def train_teacher_alone(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    epochs: int = 10,
    lr: float = 1e-4,
    weight_decay: float = 1e-4,
    lr_step_size: int = 5,
    lr_gamma: float = 0.1,
) -> nn.Module:
    """
    Train the teacher model alone on the public dataset.

    The model is moved to ``device`` and passed through
    ``partial_freeze_resnet50``. Only parameters that still have
    ``requires_grad=True`` are optimized, using Adam. The learning rate is
    multiplied by ``lr_gamma`` every ``lr_step_size`` epochs (StepLR). After
    each epoch the model is evaluated on ``val_loader`` with
    ``evaluate_teacher`` and a one-line summary is printed.

    Args:
        model: Single-logit classifier. Its output is squeezed along dim 1.
        train_loader: Yields ``(images, labels)`` training batches.
        val_loader: Yields ``(images, labels)`` validation batches.
        device: Device to train on.
        epochs: Number of training epochs.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        lr_step_size: StepLR period in epochs.
        lr_gamma: StepLR decay factor.

    Returns:
        The same ``model`` instance, trained in place.
    """
    criterion = nn.BCEWithLogitsLoss()
    model.to(device)
    partial_freeze_resnet50(model)
    optimizer = optim.Adam(
        [p for p in model.parameters() if p.requires_grad],
        lr=lr,
        weight_decay=weight_decay,
    )
    scheduler = StepLR(optimizer, step_size=lr_step_size, gamma=lr_gamma)
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        for images, labels in train_loader:
            images = images.to(device)
            # BCEWithLogitsLoss needs floating-point targets.
            labels = labels.to(device).float()
            optimizer.zero_grad()
            outputs = model(images).squeeze(1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * labels.size(0)
            preds = (torch.sigmoid(outputs) > 0.5).long()
            correct += (preds == labels.long()).sum().item()
            total += labels.size(0)
        scheduler.step()
        # Guard against an empty train loader (previously a ZeroDivisionError).
        train_loss = running_loss / total if total > 0 else 0.0
        train_acc = correct / total if total > 0 else 0.0
        val_loss, val_acc = evaluate_teacher(model, val_loader, device)
        print(
            f"[Teacher Only] Epoch {epoch+1}/{epochs} "
            f"- Train Loss: {train_loss:.4f}, Acc: {train_acc:.4f} | "
            f"Val Loss: {val_loss:.4f}, Acc: {val_acc:.4f}"
        )
    return model
# ---------------------------------------------------------------------------
# Student evaluation helpers
# ---------------------------------------------------------------------------
def evaluate_student_alone(
    model: nn.Module, loader: DataLoader, device: torch.device
) -> Tuple[float, float]:
    """
    Score the student from its fused output only (no teacher involved).

    ``loader`` must yield ``(prior_img, current_img, labels)`` triples, and
    ``model(prior, current)`` must return a 5-tuple whose first element is
    the fused logit. A sample is predicted positive when
    ``sigmoid(fused) > 0.5``.

    Returns:
        ``(mean_bce_loss, accuracy)``. Both are 0.0 for an empty loader.
    """
    bce = nn.BCEWithLogitsLoss()
    model.eval()
    sum_loss = 0.0
    hits = 0
    count = 0
    with torch.no_grad():
        for prior, current, target in loader:
            prior = prior.to(device)
            current = current.to(device)
            target = target.to(device)
            fused, _, _, _, _ = model(prior, current)
            n = target.size(0)
            sum_loss += bce(fused, target).item() * n
            hits += ((torch.sigmoid(fused) > 0.5).long() == target.long()).sum().item()
            count += n
    if count == 0:
        return 0.0, 0.0
    return sum_loss / count, hits / count
def evaluate_student_confusion(
    model: nn.Module, loader: DataLoader, device: torch.device
) -> Tuple[float, float, np.ndarray]:
    """
    Return ``(avg_loss, acc, confusion_matrix)`` for the student's fused output.

    ``loader`` must yield ``(prior_img, current_img, labels)`` triples, and
    ``model(prior, current)`` must return a 5-tuple whose first element is
    the fused logit. A sample is predicted positive when
    ``sigmoid(fused) > 0.5``.

    The confusion matrix is always 2x2 (rows = true ``[0, 1]``, columns =
    predicted ``[0, 1]``), even when a class is missing. For an empty loader
    the result is ``(0.0, 0.0, zeros((2, 2)))``.
    """
    criterion = nn.BCEWithLogitsLoss()
    model.eval()
    all_preds: List[int] = []
    all_labels: List[int] = []
    running_loss = 0.0
    total = 0
    with torch.no_grad():
        for prior_img, current_img, labels in loader:
            prior_img = prior_img.to(device)
            current_img = current_img.to(device)
            # BCEWithLogitsLoss needs floating-point targets.
            labels = labels.to(device).float()
            fused_logit, _, _, _, _ = model(prior_img, current_img)
            loss = criterion(fused_logit, labels)
            running_loss += loss.item() * labels.size(0)
            total += labels.size(0)
            preds = (torch.sigmoid(fused_logit) > 0.5).long().cpu().numpy()
            all_preds.extend(preds.tolist())
            all_labels.extend(labels.long().cpu().numpy().tolist())
    if total == 0:
        # Avoid calling sklearn metrics on empty inputs.
        return 0.0, 0.0, np.zeros((2, 2), dtype=np.int64)
    avg_loss = running_loss / total
    acc = float(accuracy_score(all_labels, all_preds))
    # Fixing labels keeps the matrix 2x2 for single-class splits.
    cm = confusion_matrix(all_labels, all_preds, labels=[0, 1])
    return avg_loss, acc, cm
# ---------------------------------------------------------------------------
# Main student training loop with policy-based teacher updates
# ---------------------------------------------------------------------------
def train_student_with_teacher(
    student_model: DualBranchStudent,
    teacher_model: nn.Module,
    teacher_extractor: TeacherFeatureExtractor,
    pc_train_loader: DataLoader,
    pc_val_loader: DataLoader,
    device: torch.device,
    *,
    public_loader: DataLoader,  # replay loader
    epochs: int = 10,
    alpha: float = 0.9,
    beta: float = 0.2,
    T: float = 2.0,
    gamma: float = 0.3,
    student_lr: float = 1e-4,
    # Teacher settings
    teacher_lr: float = 1e-6,
    weight_decay: float = 1e-4,
    tau: float = 0.10,
    confidence_threshold: float = 0.6,
    WARMUP_EPOCHS: int = 2,
    teacher_interp_lambda: float = 0.99,
    gap_threshold: float = 0.03,
    stab_threshold: float = 0.004,
    window_size: int = 5,
    # Regularizer hyper-params
    lambda_EWC: float = 1e-2,
    lambda_Cov: float = 1e-4,
    lr_cov_update: float = 1e-5,
) -> DualBranchStudent:
    """
    Jointly train the dual-branch student and, when a policy allows it, the teacher.

    Each epoch runs three steps.

    1. Student pass over ``pc_train_loader``. Two gradients are computed per
       batch:
       - one from the KD loss with ``alpha=5.0`` (the "teacher direction");
       - one from the KD loss with ``alpha=alpha``.
       The second gradient is projected onto the first and scaled by
       ``exp(tau * cos) / exp(tau)``, where ``cos`` is their cosine
       similarity. The result replaces the gradients before the optimizer
       step.
    2. Validation on ``pc_val_loader``. This produces the mean KD loss,
       the accuracy, and the mean confidence ``max(p, 1 - p)`` with
       ``p = sigmoid(fused)``.
    3. Teacher policy. The policy is skipped during the first
       ``WARMUP_EPOCHS`` epochs. After that, the teacher gets one pass over
       ``pc_train_loader`` when all of the following hold:
       - the relative val-loss improvement exceeds ``tau``, OR the mean
         confidence exceeds ``confidence_threshold``;
       - ``|train_acc - val_acc| < gap_threshold``;
       - the variance of the last ``window_size`` val accuracies is below
         ``stab_threshold``.

    A teacher update minimizes the sum of:
    - a distillation loss against the student's third output;
    - a replay loss on 16-sample batches from ``public_loader.dataset``;
    - EWC and CMD penalties. These are computed on a one-step look-ahead of
      the trainable teacher weights and compared with the weights from
      before the pass.

    Afterwards, floating-point teacher tensors whose state-dict key and
    shape match a student tensor are interpolated toward it with weight
    ``1 - teacher_interp_lambda``.

    Args:
        student_model: Student model. ``student_model(prior, current)``
            must return a 5-tuple.
        teacher_model: Teacher model. Only parameters whose names contain
            "layer4" or "fc" are trained.
        teacher_extractor: Called with ``return_features=True`` and
            ``project_features=True``. Returns ``(output, features)``.
        pc_train_loader: Yields ``(prior, current, label)`` training batches.
        pc_val_loader: Yields ``(prior, current, label)`` validation batches.
        device: Device for all models and tensors.
        public_loader: Its ``.dataset`` is re-wrapped for replay batches.
        epochs: Number of student epochs.
        alpha, beta, T, gamma: Forwarded to the KD losses.
        student_lr: Student Adam learning rate.
        teacher_lr: Teacher Adam learning rate.
        weight_decay: Adam weight decay for both optimizers.
        tau: Gradient-alignment temperature. Also the relative val-loss
            improvement threshold used by the teacher policy.
        confidence_threshold: Threshold for the mean confidence, which lies
            in ``[0.5, 1]``.
        WARMUP_EPOCHS: Number of epochs before the teacher policy is active.
        teacher_interp_lambda: Weight kept on the teacher during
            interpolation.
        gap_threshold: Maximum allowed train/val accuracy gap.
        stab_threshold: Maximum allowed variance of recent val accuracies.
        window_size: Number of recent val accuracies in the stability window.
        lambda_EWC: Weight of the EWC-style L2 penalty.
        lambda_Cov: Weight of the CMD penalty.
        lr_cov_update: Step size of the look-ahead used by the penalties.

    Returns:
        The trained ``student_model`` (updated in place). ``teacher_model``
        is also updated in place whenever the policy fires.
    """

    def _fresh_public_iter():
        # Build a new shuffled iterator of 16-sample replay batches.
        return iter(
            DataLoader(
                public_loader.dataset, batch_size=16, shuffle=True, drop_last=True
            )
        )

    def _unfreeze_teacher_top():
        # Make only "layer4" / "fc" teacher parameters trainable.
        for name, param in teacher_model.named_parameters():
            if "layer4" in name or "fc" in name:
                param.requires_grad = True

    def _teacher_outputs(pri, cur):
        # Run the teacher on both views without tracking gradients.
        with torch.no_grad():
            tp, tfp = teacher_extractor(
                pri, return_features=True, project_features=True
            )
            tc, tfc = teacher_extractor(
                cur, return_features=True, project_features=True
            )
        return tp, tc, tfp, tfc

    def _student_kd_loss(student_out, teacher_out, y, kd_alpha):
        # Compute the dual-feature KD loss with the given alpha weighting.
        fused, sp, sc, fp, fc = student_out
        tp, tc, tfp, tfc = teacher_out
        return student_distillation_loss_dual_feature(
            fused, sp, sc, fp, fc, tp, tc, tfp, tfc, y,
            alpha=kd_alpha, beta=beta, T=T, gamma=gamma,
        )

    def _eval_student():
        # Return (mean KD loss, accuracy, mean confidence) on pc_val_loader.
        # Confidence is max(p, 1 - p), which lies in [0.5, 1]. The previous
        # |p - 0.5| was capped at 0.5, so it could never exceed
        # confidence_threshold (0.6 / 0.7).
        student_model.eval()
        teacher_extractor.eval()
        tot, corr = 0, 0
        loss_sum, conf_sum = 0.0, 0.0
        with torch.no_grad():
            for pri, cur, y in pc_val_loader:
                pri, cur, y = pri.to(device), cur.to(device), y.to(device)
                student_out = student_model(pri, cur)
                loss = _student_kd_loss(
                    student_out, _teacher_outputs(pri, cur), y, alpha
                )
                probs = torch.sigmoid(student_out[0])
                loss_sum += loss.item() * y.size(0)
                corr += ((probs > 0.5).long() == y.long()).sum().item()
                tot += y.size(0)
                conf_sum += (0.5 + torch.abs(probs - 0.5)).sum().item()
        if tot == 0:
            return 0.0, 0.0, 0.0
        return loss_sum / tot, corr / tot, conf_sum / tot

    # Replay iterator over public data (teacher dataset)
    public_iter = _fresh_public_iter()

    student_model.to(device)
    teacher_model.to(device)
    teacher_extractor.to(device)

    student_opt = optim.Adam(
        [p for p in student_model.parameters() if p.requires_grad],
        lr=student_lr,
        weight_decay=weight_decay,
    )

    # Partial freeze teacher: only layer4 + fc trainable
    partial_freeze_resnet50(teacher_model)
    _unfreeze_teacher_top()
    teacher_opt = optim.Adam(
        [p for p in teacher_model.parameters() if p.requires_grad],
        lr=teacher_lr,
        weight_decay=weight_decay,
    )
    sched_student = StepLR(student_opt, step_size=5, gamma=0.1)
    sched_teacher = StepLR(teacher_opt, step_size=5, gamma=0.1)

    base_loss, base_acc, base_conf = _eval_student()
    print(
        f"Baseline student – ValLoss={base_loss:.4f} "
        f"Acc={base_acc:.4f} Conf={base_conf:.4f}"
    )
    prev_val_loss = base_loss
    val_acc_window: List[float] = []

    for epoch in range(epochs):
        student_model.train()
        teacher_extractor.eval()
        run_loss, run_corr, run_tot = 0.0, 0, 0

        # ----------- Student update --------------------------------------
        for pri, cur, y in pc_train_loader:
            pri, cur, y = pri.to(device), cur.to(device), y.to(device)
            student_opt.zero_grad()
            student_out = student_model(pri, cur)
            fused = student_out[0]
            teacher_out = _teacher_outputs(pri, cur)

            # 1) Gradient of a distillation-dominated loss (teacher direction).
            # presumably flatten_grad concatenates parameter .grad values and
            # returns None when there are none — TODO confirm.
            dist_only = _student_kd_loss(student_out, teacher_out, y, 5.0)
            dist_only.backward(retain_graph=True)
            g_Tdir = flatten_grad(list(student_model.parameters()))
            if g_Tdir is not None:
                g_Tdir = g_Tdir.detach()

            # 2) Gradient of the full loss (teacher + hard labels).
            student_opt.zero_grad()
            full_loss = _student_kd_loss(student_out, teacher_out, y, alpha)
            full_loss.backward()
            g_full = flatten_grad(list(student_model.parameters()))

            # 3) Replace the gradient with the projection of the full gradient
            #    onto the teacher direction, scaled by exp(tau * (cos - 1)).
            if g_Tdir is not None and g_full is not None:
                dot = g_full.dot(g_Tdir)
                cos = (dot / (g_full.norm() * g_Tdir.norm() + 1e-9)).item()
                lam = math.exp(tau * cos) / math.exp(tau)
                proj = dot / (g_Tdir.dot(g_Tdir) + 1e-9)
                load_flattened_grad(
                    list(student_model.parameters()), lam * proj * g_Tdir
                )
            student_opt.step()

            run_loss += full_loss.item() * y.size(0)
            run_corr += (torch.sigmoid(fused) > 0.5).long().eq(y.long()).sum().item()
            run_tot += y.size(0)

        sched_student.step()
        # Guard against an empty train loader.
        tr_loss = run_loss / run_tot if run_tot > 0 else 0.0
        tr_acc = run_corr / run_tot if run_tot > 0 else 0.0
        val_loss, val_acc, val_conf = _eval_student()
        print(
            f"Epoch {epoch+1}/{epochs} – TrainLoss {tr_loss:.4f} "
            f"TrainAcc {tr_acc:.4f} | ValLoss {val_loss:.4f} "
            f"ValAcc {val_acc:.4f}"
        )

        # ------------------------------------------------------------------
        # Decide if teacher should be updated (policy)
        # ------------------------------------------------------------------
        val_acc_window.append(val_acc)
        if len(val_acc_window) > window_size:
            val_acc_window.pop(0)

        if epoch < WARMUP_EPOCHS:
            print(" Teacher update skipped (warm-up phase).")
            prev_val_loss = val_loss
            continue

        delta = (
            (prev_val_loss - val_loss) / (prev_val_loss + 1e-9)
            if prev_val_loss > 0
            else 0.0
        )
        cond_orig = (delta > tau) or (val_conf > confidence_threshold)
        cond_gap = abs(tr_acc - val_acc) < gap_threshold
        var_val = (
            torch.var(torch.tensor(val_acc_window)).item() if val_acc_window else 0.0
        )
        cond_stab = var_val < stab_threshold

        if cond_orig and cond_gap and cond_stab:
            # Enable grads only on the unfrozen top of the teacher.
            teacher_model.train()
            for p in teacher_model.parameters():
                p.requires_grad = False
            _unfreeze_teacher_top()
            t_params = [p for p in teacher_model.parameters() if p.requires_grad]
            W_old = [p.clone().detach() for p in t_params]
            var_old = _layer_variances(W_old)

            # One pass over the training data to update the teacher.
            for pri, cur, y in pc_train_loader:
                pri, cur, y = pri.to(device), cur.to(device), y.to(device)
                teacher_opt.zero_grad()

                # Step 1: teacher loss on the current batch
                t_logits = teacher_model(cur).squeeze(1)
                with torch.no_grad():
                    _, _, sc, _, _ = student_model(pri, cur)
                base_t_loss = teacher_distillation_loss(
                    t_logits, sc, y, alpha=0.7, T=T
                )

                # Step 1b: replay loss on old (public) data
                try:
                    pub_imgs, pub_labels = next(public_iter)
                except StopIteration:
                    public_iter = _fresh_public_iter()
                    pub_imgs, pub_labels = next(public_iter)
                pub_imgs, pub_labels = pub_imgs.to(device), pub_labels.to(device)
                pub_logits = teacher_model(pub_imgs).squeeze(1)
                replay_loss = teacher_distillation_loss(
                    pub_logits, pub_logits.detach(), pub_labels, alpha=1.0, T=T
                )
                base_t_loss = base_t_loss + replay_loss

                # Step 2: gradients for the projection logic
                base_t_loss.backward(retain_graph=True)
                g_T = flatten_grad(t_params)  # type: ignore
                t_prob = torch.sigmoid(t_logits / T)
                s_prob = torch.sigmoid(sc / T)
                proxy_dist = (1 - 0.7) * nn.KLDivLoss(reduction="batchmean")(
                    torch.log(t_prob + 1e-7), s_prob
                )
                g_S_tuple = torch.autograd.grad(
                    proxy_dist, t_params, retain_graph=True, allow_unused=True
                )
                g_S = torch.cat(
                    [
                        (g if g is not None else torch.zeros_like(p)).view(-1)
                        for g, p in zip(g_S_tuple, t_params)
                    ]
                )

                # Step 3: projection & scaling
                dot_ts = g_T.dot(g_S)  # type: ignore
                cos = (dot_ts / (g_T.norm() * g_S.norm() + 1e-9)).item()  # type: ignore
                lam = math.exp(tau * cos) / math.exp(tau)
                proj = dot_ts / (g_T.dot(g_T) + 1e-9)  # type: ignore
                g_aligned = (lam * proj * g_T).detach()  # type: ignore

                # Step 4: one-step look-ahead weights for the regularizers.
                # These weights must stay attached to the parameters' autograd
                # graph. The previous code used p.detach(), which made `reg` a
                # constant that contributed no gradient to the teacher step.
                W_temp, off = [], 0
                for p in t_params:
                    numel = p.numel()
                    step = g_aligned[off : off + numel].view_as(p)
                    W_temp.append(p - lr_cov_update * step)
                    off += numel

                # Step 5: regularization (CMD + EWC-style).
                # NOTE(review): assumes _layer_variances / covariance_cmd use
                # differentiable torch ops (no .numpy() on grad tensors) — confirm.
                var_temp = _layer_variances(W_temp)
                cmd = covariance_cmd(var_temp, var_old)
                ewc = sum(
                    (w_t - w_o).pow(2).sum() for w_t, w_o in zip(W_temp, W_old)
                )
                reg = lambda_Cov * cmd + lambda_EWC * ewc

                # Step 6: final composite loss & parameter update
                final_loss = base_t_loss + reg
                teacher_opt.zero_grad()
                final_loss.backward()
                teacher_opt.step()

            sched_teacher.step()

            # Interpolate teacher weights toward student weights. Only
            # floating-point tensors are blended; integer buffers such as
            # num_batches_tracked cannot take an in-place float multiply.
            with torch.no_grad():
                t_sd = teacher_model.state_dict()
                s_sd = student_model.state_dict()
                for key, t_tensor in t_sd.items():
                    s_tensor = s_sd.get(key)
                    if (
                        s_tensor is not None
                        and t_tensor.shape == s_tensor.shape
                        and t_tensor.is_floating_point()
                    ):
                        t_tensor.mul_(teacher_interp_lambda).add_(
                            s_tensor * (1 - teacher_interp_lambda)
                        )

            # Re-freeze
            for p in teacher_model.parameters():
                p.requires_grad = False
            partial_freeze_resnet50(teacher_model)
            teacher_model.eval()
            print(" Teacher updated ✅")
        else:
            print(" Teacher not updated (criteria unmet).")

        prev_val_loss = val_loss
        print("―" * 72)

    return student_model
# ---------------------------------------------------------------------------
# Main entry point
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import random  # used only for the prior/current train/val split

    set_seed(42)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # ------------------------------------------------------------
    # 1) Teacher dataset & model
    # ------------------------------------------------------------
    teacher_root = "./Dataset4"  # TODO: adjust for your setup
    teacher_train_loader, teacher_val_loader = create_teacher_dataloaders(
        root_dir=teacher_root,
        batch_size=16,
    )
    teacher_model = create_teacher_model(num_classes=1)

    # Train teacher
    t0 = time.perf_counter()
    teacher_model = train_teacher_alone(
        teacher_model,
        teacher_train_loader,
        teacher_val_loader,
        device,
        epochs=12,
        lr=1e-4,
        weight_decay=1e-4,
    )
    t1 = time.perf_counter()
    print(f"[MAIN] Teacher training time: {t1 - t0:.2f}s")
    torch.save(teacher_model.state_dict(), "teacher_only.pth")

    print("\n=== [Teacher-Only Model] Final Evaluations ===")
    teacher_train_loss, teacher_train_acc, teacher_train_cm = evaluate_teacher_confusion(
        teacher_model, teacher_train_loader, device
    )
    teacher_val_loss, teacher_val_acc, teacher_val_cm = evaluate_teacher_confusion(
        teacher_model, teacher_val_loader, device
    )
    print(f"Teacher [Train] Loss={teacher_train_loss:.4f} Acc={teacher_train_acc:.4f}")
    print(f"Confusion Matrix (Train):\n{teacher_train_cm}\n")
    print(f"Teacher [Val] Loss={teacher_val_loss:.4f} Acc={teacher_val_acc:.4f}")
    print(f"Confusion Matrix (Val):\n{teacher_val_cm}\n")

    # ------------------------------------------------------------
    # 2) Prior-Current dataset for student
    # ------------------------------------------------------------
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    pc_transform = transforms.Compose(
        [
            transforms.RandomResizedCrop((512, 512), scale=(0.8, 1.0)),
            transforms.RandomVerticalFlip(),
            transforms.RandomRotation(5),
            transforms.RandomHorizontalFlip(),
            transforms.ColorJitter(0.1, 0.1, 0.1, 0.05),
            transforms.ToTensor(),
            transforms.Normalize(imagenet_mean, imagenet_std),
        ]
    )
    # Validation must be deterministic. Random crops, flips and jitter on the
    # validation split make the metrics that drive the teacher-update policy
    # noisy.
    pc_eval_transform = transforms.Compose(
        [
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
            transforms.Normalize(imagenet_mean, imagenet_std),
        ]
    )
    pc_dataset_root = "./CurrentPrior512aug2"  # TODO: adjust
    pc_dataset = PriorCurrentDataset(
        root_dir=pc_dataset_root,
        transform=pc_transform,
        augment=True,
    )

    pc_size = len(pc_dataset)
    train_size = int(0.8 * pc_size)
    indices_pc = list(range(pc_size))
    # Shuffle before building the eval view so the split matches the
    # original RNG sequence.
    random.shuffle(indices_pc)
    train_indices_pc = indices_pc[:train_size]
    val_indices_pc = indices_pc[train_size:]

    # Non-augmented view of the same files, used only for validation.
    # NOTE(review): assumes PriorCurrentDataset enumerates a given root_dir in
    # a deterministic order. The label comparison below guards this.
    pc_eval_dataset = PriorCurrentDataset(
        root_dir=pc_dataset_root,
        transform=pc_eval_transform,
        augment=False,
    )
    if list(pc_eval_dataset.labels) != list(pc_dataset.labels):
        raise RuntimeError(
            "Augmented and evaluation views of the prior-current dataset are not aligned."
        )

    train_dataset_pc = torch.utils.data.Subset(pc_dataset, train_indices_pc)
    val_dataset_pc = torch.utils.data.Subset(pc_eval_dataset, val_indices_pc)
    pc_train_labels = [pc_dataset.labels[i] for i in train_dataset_pc.indices]
    pc_train_sampler = create_weighted_sampler(pc_train_labels)
    pc_train_loader = DataLoader(
        train_dataset_pc,
        batch_size=16,
        sampler=pc_train_sampler,
        shuffle=False,
        drop_last=True,
    )
    pc_val_loader = DataLoader(
        val_dataset_pc,
        batch_size=16,
        shuffle=False,
    )

    # ------------------------------------------------------------
    # 3) Student model & initial evaluation
    # ------------------------------------------------------------
    student_model = DualBranchStudent(num_classes=1).to(device)
    print("\n=== Student Evaluation BEFORE Distillation ===")
    train_loss_alone, train_acc_alone = evaluate_student_alone(
        student_model, pc_train_loader, device
    )
    val_loss_alone, val_acc_alone = evaluate_student_alone(
        student_model, pc_val_loader, device
    )
    print(
        f"Student-alone (Untrained) - "
        f"Train Loss: {train_loss_alone:.4f}, Train Acc: {train_acc_alone:.4f} | "
        f"Val Loss: {val_loss_alone:.4f}, Val Acc: {val_acc_alone:.4f}\n"
    )

    # Teacher feature extractor for KD
    teacher_extractor = TeacherFeatureExtractor(teacher_model)

    # ------------------------------------------------------------
    # 4) Train student with policy-based teacher updates
    # ------------------------------------------------------------
    t0 = time.perf_counter()
    student_model = train_student_with_teacher(
        student_model=student_model,
        teacher_model=teacher_model,
        teacher_extractor=teacher_extractor,
        pc_train_loader=pc_train_loader,
        pc_val_loader=pc_val_loader,
        device=device,
        # Replay on the teacher's *training* split. Replaying on the val split
        # would train on the data used below to evaluate the updated teacher.
        public_loader=teacher_train_loader,
        epochs=20,
        alpha=0.9,
        beta=0.2,
        T=2.0,
        gamma=0.6,
        student_lr=1e-4,
        teacher_lr=1e-6,
        weight_decay=1e-4,
        tau=0.10,
        confidence_threshold=0.70,
        WARMUP_EPOCHS=7,
        teacher_interp_lambda=0.99,
    )
    t1 = time.perf_counter()
    print(f"[MAIN] Student+Policy training time: {t1 - t0:.2f}s")
    torch.save(teacher_model.state_dict(), "teacher_updated.pth")
    torch.save(student_model.state_dict(), "student_final.pth")
    print("DONE. Teacher → Student training with policy-based updates complete.\n")

    # ------------------------------------------------------------
    # 5) Final evaluations
    # ------------------------------------------------------------
    print("=== Evaluate Updated Teacher on Public Dataset ===")
    teacher_updated = create_teacher_model(num_classes=1)
    teacher_updated.load_state_dict(
        torch.load("teacher_updated.pth", map_location=device)
    )
    teacher_updated.to(device).eval()
    updated_train_loss, updated_train_acc, updated_train_cm = evaluate_teacher_confusion(
        teacher_updated, teacher_train_loader, device
    )
    updated_val_loss, updated_val_acc, updated_val_cm = evaluate_teacher_confusion(
        teacher_updated, teacher_val_loader, device
    )
    print(
        f"Updated Teacher [Train] Loss={updated_train_loss:.4f} "
        f"Acc={updated_train_acc:.4f}"
    )
    print(f"Confusion Matrix:\n{updated_train_cm}\n")
    print(
        f"Updated Teacher [Val] Loss={updated_val_loss:.4f} "
        f"Acc={updated_val_acc:.4f}"
    )
    print(f"Confusion Matrix:\n{updated_val_cm}\n")

    print("=== Evaluate Student on Prior-Current Dataset ===")
    student_loaded = DualBranchStudent(num_classes=1)
    student_loaded.load_state_dict(
        torch.load("student_final.pth", map_location=device)
    )
    student_loaded.to(device).eval()
    s_train_loss, s_train_acc, s_train_cm = evaluate_student_confusion(
        student_loaded, pc_train_loader, device
    )
    print(f"Student [Train] Loss={s_train_loss:.4f} Acc={s_train_acc:.4f}")
    print(f"Confusion Matrix:\n{s_train_cm}\n")
    s_val_loss, s_val_acc, s_val_cm = evaluate_student_confusion(
        student_loaded, pc_val_loader, device
    )
    print(f"Student [Val] Loss={s_val_loss:.4f} Acc={s_val_acc:.4f}")
    print(f"Confusion Matrix:\n{s_val_cm}\n")