Teacher in the Loop, Student on Time: Policy-Guided Continual Learning from Prior–Current Mammograms
This repository contains the official implementation of the methods described in:
Teacher in the Loop, Student on Time: Policy-Guided Continual Learning from Prior–Current Mammograms
Authors: Sahand Hamzehei, Afsana Ahsan Jeny, Mostafa Karami, Stephen Andrew Baker, Tucker Van Rathe, Clifford Yang, Sheida Nabavi
Submitted to: IEEE Transactions on Medical Imaging
Year: 2025
The code implements a teacher–student framework for mammography, where:
- A ResNet50 teacher is trained on a public mammography dataset (Normal vs Abnormal).
- A dual-branch ResNet18 student learns from paired prior–current exams.
- Training uses:
  - Knowledge distillation on logits and features,
  - Gradient alignment,
  - Policy-based teacher updates (only when student validation metrics justify it),
  - Replay on teacher data to avoid catastrophic forgetting,
  - Regularization (EWC-style + covariance/CMD-style).
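As a rough sketch of the logit-distillation part of this recipe (illustrative only — the function name, weighting, and the exact losses used by the repo live in losses.py):

```python
import torch
import torch.nn.functional as F

def kd_logit_loss(student_logits, teacher_logits, labels, T=4.0, alpha=0.5):
    """Illustrative KD loss: cross-entropy on hard labels plus KL divergence
    between temperature-softened student and teacher logits."""
    ce = F.cross_entropy(student_logits, labels)
    kl = F.kl_div(
        F.log_softmax(student_logits / T, dim=1),
        F.softmax(teacher_logits / T, dim=1),
        reduction="batchmean",
    ) * (T * T)  # standard T^2 scaling so gradients keep magnitude as T grows
    return alpha * ce + (1 - alpha) * kl
```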
.
├─ datasets.py # Dataset definitions and teacher dataloaders
├─ models.py # Teacher model, student model, feature extractor
├─ losses.py # Distillation, feature, and gradient-related losses
├─ train.py # End-to-end training and evaluation script
├─ gradcam_explain.py # Grad-CAM visualization for teacher and student
└─ README.md
Tested with:
- Python ≥ 3.8
- PyTorch ≥ 1.13 (or ≥ 2.0)
- CUDA (optional but recommended)
You can install dependencies with:
pip install torch torchvision
pip install numpy scikit-learn pillow matplotlib opencv-python
The teacher trains on a binary classification dataset with two folders:
DATASET_TEACHER/
├─ normal/
│ ├─ img1.tif
│ ├─ img2.tif
│ └─ ...
└─ mass/
├─ img3.tif
├─ img4.tif
└─ ...
Normal images → label 0
Abnormal images → label 1
In train.py, this path is controlled via:
teacher_root = "./Dataset Path" # change this to your teacher dataset root
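For reference, the two-folder layout and label convention above can be enumerated with a few lines of stdlib Python (illustrative; the repo's actual teacher dataloaders are defined in datasets.py):

```python
from pathlib import Path

def list_teacher_samples(root):
    """Illustrative: collect (path, label) pairs following the convention
    above — normal -> 0, mass (abnormal) -> 1."""
    samples = []
    for label, folder in [(0, "normal"), (1, "mass")]:
        for img in sorted(Path(root, folder).glob("*.tif")):
            samples.append((str(img), label))
    return samples
```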
The student uses paired prior–current exams for each patient. Expected structure:
DATASET_PRIOR_CURRENT/
└─ PATIENT001/
├─ PATIENT001_prior/ # folder name ends with 'prior' or 'p'
│ ├─ ...LCC.tif
│ ├─ ...LMLO.tif
│ └─ ...
└─ PATIENT001_current/ # folder name ends with 'current' or 'c'
├─ ...LCC_MASS.tif
├─ ...LMLO_NORMAL.tif
└─ ...
Key assumptions (implemented in PriorCurrentDataset in datasets.py):
- Prior and current views are matched using view identifiers such as LCC, LMLO, RCC, RMLO inside the filenames.
- The label is derived from the current exam filename using keywords in abnormal_terms (default: ["MASS", "Mass", "ARCH", "Arch", "CALC", "Calc"]):
- If any keyword appears in the current filename → label 1 (abnormal)
- Otherwise → label 0 (normal)
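A minimal sketch of this label rule (illustrative; the actual logic lives in PriorCurrentDataset in datasets.py):

```python
# Default keyword list from the README; abnormal if any term appears
# in the current-exam filename.
ABNORMAL_TERMS = ["MASS", "Mass", "ARCH", "Arch", "CALC", "Calc"]

def label_from_current_filename(filename, abnormal_terms=ABNORMAL_TERMS):
    """Return 1 (abnormal) if any keyword appears in the filename, else 0."""
    return int(any(term in filename for term in abnormal_terms))
```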
In train.py, this dataset root is set by:
pc_dataset_root = "./Name of Folder" # change this to your prior-current dataset root
If your filenames or folder names use different patterns, you typically only need to edit:
- abnormal_terms in PriorCurrentDataset.init
- PriorCurrentDataset._extract_identifier() (how views like CC/MLO are detected)
- The logic in _prepare_pairs() that detects prior/current subfolders (currently checks for suffixes like "prior", "p", "current", "c").
All such logic is localized inside datasets.py and documented with comments.
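Under the conventions above, the prior/current detection and view matching could be sketched like this (illustrative helper names; the real logic is in PriorCurrentDataset._prepare_pairs() and _extract_identifier()):

```python
VIEWS = ["LCC", "LMLO", "RCC", "RMLO"]

def is_prior(folder_name):
    """Folder names ending in 'prior' or 'p' are treated as prior exams."""
    name = folder_name.lower()
    return name.endswith("prior") or name.endswith("p")

def is_current(folder_name):
    """Folder names ending in 'current' or 'c' are treated as current exams."""
    name = folder_name.lower()
    return name.endswith("current") or name.endswith("c")

def extract_view(filename):
    """Return the first view identifier (LCC/LMLO/RCC/RMLO) in the filename."""
    upper = filename.upper()
    # Check longer identifiers first so 4-letter views are preferred.
    for view in sorted(VIEWS, key=len, reverse=True):
        if view in upper:
            return view
    return None
```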
All training code lives in train.py.
- Train the Teacher

This trains a ResNet50 on the public teacher dataset and saves teacher_only.pth:

python train.py

By default, train.py will:
- Load the teacher dataset from teacher_root.
- Train the teacher with partial freezing (only layer4 and fc are fine-tuned).
- Save the weights to teacher_only.pth.
- Print training/validation accuracy and confusion matrices.
Note: The same script also continues to the student training step (below). If you only want the teacher, add an early return or comment out the student section.
- Train the Student with Policy-Based Teacher Updates

In the second part of train.py, the script:

- Builds the prior–current dataset (PriorCurrentDataset) and splits it into train/val.
- Initializes the dual-branch ResNet18 student (DualBranchStudent).
- Wraps the teacher into a feature extractor (TeacherFeatureExtractor).
- Runs train_student_with_teacher(), which:
  - Uses a composite distillation loss:
    - hard labels on the fused output,
    - teacher–student logit distillation (prior & current),
    - feature distillation (MSE),
    - a student–student consistency term.
  - Aligns gradients with a "teacher-direction" gradient.
  - Updates the teacher only conditionally, based on:
    - student validation loss improvement,
    - the train/val accuracy gap,
    - stability of validation accuracy over a sliding window.
  - Uses replay from the teacher dataset to regularize the teacher.
  - Adds CMD-style covariance regularization and an EWC-style weight drift penalty.
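The teacher-update policy can be sketched as a simple gate over the student's recent metrics (illustrative names and thresholds; the real policy is inside train_student_with_teacher() in train.py):

```python
def should_update_teacher(val_losses, train_accs, val_accs,
                          window=3, max_gap=0.10, max_std=0.02):
    """Illustrative gate: update the teacher only when the student's
    validation loss is improving, the train/val accuracy gap is small
    (not overfitting), and validation accuracy is stable over a window."""
    if len(val_losses) < window + 1:
        return False  # not enough history yet (warm-up)
    improving = val_losses[-1] < min(val_losses[-window - 1:-1])
    gap_ok = (train_accs[-1] - val_accs[-1]) <= max_gap
    recent = val_accs[-window:]
    mean = sum(recent) / window
    std = (sum((a - mean) ** 2 for a in recent) / window) ** 0.5
    return improving and gap_ok and std <= max_std
```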
At the end, the script saves:
teacher_updated.pth # teacher after policy-based updates
student_final.pth # final student model
Hyperparameters (e.g. alpha, beta, gamma, T, tau, WARMUP_EPOCHS) are all specified in the train_student_with_teacher call near the bottom of train.py and can be changed there.
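For reference, the EWC-style weight drift penalty mentioned above can be sketched as a Fisher-weighted squared drift from a teacher snapshot (illustrative; the repo's actual regularizers live in losses.py and train.py):

```python
import torch

def ewc_penalty(model, ref_params, fisher, lam=1.0):
    """Illustrative EWC-style penalty: penalize drift of the teacher's
    parameters from a reference (pre-update) snapshot, weighted by a
    per-parameter Fisher importance estimate."""
    loss = 0.0
    for name, param in model.named_parameters():
        if name in ref_params:
            loss = loss + (fisher[name] * (param - ref_params[name]) ** 2).sum()
    return lam * loss
```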
