"""
models.py
Teacher and student architectures:
- ResNet50-based teacher (binary classifier)
- Dual-branch ResNet18 student (prior/current)
- TeacherFeatureExtractor for feature-level distillation
"""
# Sample code for the usual dataset at 512 x 512 inputs; for the exact (full-resolution) method, change the input size to 1024 x 1024.
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
# ---------------------------------------------------------------------------
# Teacher model
# ---------------------------------------------------------------------------
def create_teacher_model(num_classes: int = 1, dropout: float = 0.5) -> nn.Module:
    """
    Build a ResNet50 teacher model for binary classification.

    The stock 1000-way ImageNet classifier head is replaced by:
        BatchNorm1d -> Dropout -> Linear(num_classes)

    Args:
        num_classes: Output dimension of the final linear layer
            (1 for binary classification with BCE-with-logits).
        dropout: Dropout probability of the replacement head
            (default 0.5, matching the previous hard-coded value).

    Returns:
        The modified ResNet50 as an ``nn.Module``.
    """
    # NOTE(review): ``pretrained=`` is deprecated since torchvision 0.13;
    # the modern equivalent is ``weights=models.ResNet50_Weights.IMAGENET1K_V1``.
    model = models.resnet50(pretrained=True)
    in_features = model.fc.in_features  # type: ignore
    # Swap the ImageNet head for a regularized binary classification head.
    model.fc = nn.Sequential(  # type: ignore
        nn.BatchNorm1d(in_features),
        nn.Dropout(dropout),
        nn.Linear(in_features, num_classes),
    )  # type: ignore
    return model
def partial_freeze_resnet50(model: nn.Module) -> None:
    """
    Freeze every ResNet50 parameter except those in ``layer4`` and ``fc``.

    Keeping only the last residual stage and the classifier trainable is a
    common fine-tuning strategy for small datasets.
    """
    trainable_keys = ("layer4", "fc")
    for name, param in model.named_parameters():
        # Trainable iff the parameter belongs to layer4 or the classifier.
        param.requires_grad = any(key in name for key in trainable_keys)
# ---------------------------------------------------------------------------
# Student model: Dual-branch ResNet18
# ---------------------------------------------------------------------------
class DualBranchStudent(nn.Module):
    """
    Two-stream ResNet18 student for paired prior/current mammograms.

    Each time point is encoded by its own ResNet18 backbone whose original
    ``fc`` is replaced by ``nn.Identity`` so the backbone emits raw features.
    Every stream has its own classification head, and a fusion head consumes
    the concatenated features to produce the final binary logit.

    forward() returns:
        fused_logit, logit_prior, logit_current, feat_prior, feat_current
    """

    def __init__(self, num_classes: int = 1):
        super().__init__()
        # Independent backbones, one per time point.
        # NOTE(review): ``pretrained=`` is deprecated in torchvision >= 0.13.
        self.branch_prior = models.resnet18(pretrained=True)
        self.branch_current = models.resnet18(pretrained=True)
        self._partial_freeze(self.branch_prior)
        self._partial_freeze(self.branch_current)

        feat_dim = self.branch_prior.fc.in_features  # type: ignore
        # Drop the ImageNet classifiers; the backbones become feature extractors.
        self.branch_prior.fc = nn.Identity()  # type: ignore
        self.branch_current.fc = nn.Identity()  # type: ignore

        def _head(dim: int, drop: float) -> nn.Sequential:
            # BN -> Dropout -> Linear classification head.
            return nn.Sequential(
                nn.BatchNorm1d(dim),
                nn.Dropout(drop),
                nn.Linear(dim, num_classes),
            )

        # Per-stream heads, then the fusion head over concatenated features.
        self.prior_head = _head(feat_dim, 0.3)
        self.current_head = _head(feat_dim, 0.3)
        self.fusion = _head(feat_dim * 2, 0.5)

    @staticmethod
    def _partial_freeze(backbone: nn.Module) -> None:
        """
        Freeze the early layers of a ResNet18 and leave layer3/layer4 trainable.

        This reduces overfitting and training time.
        """
        for name, param in backbone.named_parameters():
            param.requires_grad = ("layer3" in name) or ("layer4" in name)

    def forward(
        self, prior_img: torch.Tensor, current_img: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # Per-stream embeddings.
        feat_prior = self.branch_prior(prior_img)
        feat_current = self.branch_current(current_img)

        # Per-stream logits, squeezed to shape (batch,).
        logit_prior = self.prior_head(feat_prior).squeeze(1)
        logit_current = self.current_head(feat_current).squeeze(1)

        # Fusion over the concatenated embeddings.
        combined = torch.cat([feat_prior, feat_current], dim=1)
        fused_logit = self.fusion(combined).squeeze(1)

        return fused_logit, logit_prior, logit_current, feat_prior, feat_current
# ---------------------------------------------------------------------------
# Teacher feature extractor for distillation
# ---------------------------------------------------------------------------
class TeacherFeatureExtractor(nn.Module):
    """
    Wrapper exposing a ResNet50 teacher's intermediate features for distillation.

    Components:
        teacher_body: conv stem through layer4 (produces the final feature map)
        teacher_fc:   the teacher's classifier head
        teacher_proj: linear 2048 -> 512 projection for feature-level losses
    """

    def __init__(self, teacher_model: nn.Module):
        super().__init__()
        # Re-assemble the teacher trunk: everything before global pooling.
        trunk = (
            teacher_model.conv1,  # type: ignore
            teacher_model.bn1,  # type: ignore
            teacher_model.relu,  # type: ignore
            teacher_model.maxpool,  # type: ignore
            teacher_model.layer1,  # type: ignore
            teacher_model.layer2,  # type: ignore
            teacher_model.layer3,  # type: ignore
            teacher_model.layer4,  # type: ignore
        )
        self.teacher_body = nn.Sequential(*trunk)
        self.teacher_fc = teacher_model.fc  # type: ignore
        # Projects 2048-d pooled features down to the student's 512-d space.
        self.teacher_proj = nn.Linear(2048, 512)

    def forward(
        self,
        x: torch.Tensor,
        return_features: bool = False,
        project_features: bool = False,
    ):
        """
        Run the teacher and optionally return pooled features for KD.

        Returns ``logits`` alone when ``return_features`` is False; otherwise
        ``(logits, features)`` where features are either the raw pooled
        vector or its 512-d projection when ``project_features`` is True.
        """
        feature_map = self.teacher_body(x)
        # Global average pool, flattened to (batch, channels).
        pooled = F.adaptive_avg_pool2d(feature_map, (1, 1)).view(x.size(0), -1)
        logits = self.teacher_fc(pooled).squeeze(1)  # type: ignore

        if not return_features:
            return logits
        if project_features:
            return logits, self.teacher_proj(pooled)
        return logits, pooled