|
4 | 4 | import torch |
5 | 5 | import torch.nn as nn |
6 | 6 |
|
| 7 | +from torchvision.transforms.functional import normalize |
| 8 | +from torchvision.models import vgg16_bn |
| 9 | + |
| 10 | +from robosat.hooks import FeatureHook |
| 11 | + |
7 | 12 |
|
8 | 13 | class CrossEntropyLoss2d(nn.Module): |
9 | 14 | """Cross-entropy. |
@@ -117,3 +122,123 @@ def forward(self, inputs, targets): |
117 | 122 | loss += torch.dot(nn.functional.relu(errors_sorted), iou) |
118 | 123 |
|
119 | 124 | return loss / N |
| 125 | + |
| 126 | + |
| 127 | +class CombinedLoss(nn.Module): |
| 128 | + """Weighted combination of losses. |
| 129 | + """ |
| 130 | + |
| 131 | + def __init__(self, criteria, weights): |
| 132 | + """Creates a `CombinedLosses` instance. |
| 133 | +
|
| 134 | + Args: |
| 135 | + criteria: list of criteria to combine |
| 136 | + weights: tensor to tune losses with |
| 137 | + """ |
| 138 | + |
| 139 | + super().__init__() |
| 140 | + |
| 141 | + assert len(weights.size()) == 1 |
| 142 | + assert weights.size(0) == len(criteria) |
| 143 | + |
| 144 | + self.criteria = criteria |
| 145 | + self.weights = weights |
| 146 | + |
| 147 | + def forward(self, inputs, targets): |
| 148 | + loss = 0.0 |
| 149 | + |
| 150 | + for criterion, w in zip(self.criteria, self.weights): |
| 151 | + each = w * criterion(inputs, targets) |
| 152 | + print(type(criterion).__name__, each.item()) # Todo: remove |
| 153 | + loss += each |
| 154 | + |
| 155 | + return loss |
| 156 | + |
| 157 | + |
| 158 | +class TopologyLoss(nn.Module): |
| 159 | + """Topology loss working on a pre-trained model's feature map similarities. |
| 160 | +
|
| 161 | + See: |
| 162 | + - https://arxiv.org/abs/1603.08155 |
| 163 | + - https://arxiv.org/abs/1712.02190 |
| 164 | +
|
| 165 | + Note: implementation works with single channel tensors and stacks them for VGG. |
| 166 | + """ |
| 167 | + |
| 168 | + def __init__(self, blocks, weights): |
| 169 | + """Creates a `TopologyLoss` instance. |
| 170 | +
|
| 171 | + Args: |
| 172 | + blocks: list of block indices to use, in `[0, 6]` (e.g. `[0, 1, 2]`) |
| 173 | + weights: tensor to tune losses per block (e.g. `[0.2, 0.6, 0.2]`) |
| 174 | +
|
| 175 | + Note: the block indices correspond to a pre-trained VGG's feature maps to use. |
| 176 | + """ |
| 177 | + |
| 178 | + super().__init__() |
| 179 | + |
| 180 | + assert len(weights.size()) == 1 |
| 181 | + assert weights.size(0) == len(blocks) |
| 182 | + |
| 183 | + self.weights = weights |
| 184 | + |
| 185 | + assert len(blocks) <= 5 |
| 186 | + assert all(i in range(5) for i in blocks) |
| 187 | + assert sorted(blocks) == blocks |
| 188 | + |
| 189 | + features = vgg16_bn(pretrained=True).features |
| 190 | + features.eval() |
| 191 | + |
| 192 | + for param in features.parameters(): |
| 193 | + param.requires_grad = False |
| 194 | + |
| 195 | + relus = [i - 1 for i, m in enumerate(features) if isinstance(m, nn.MaxPool2d)] |
| 196 | + |
| 197 | + self.hooks = [FeatureHook(features[relus[i]]) for i in blocks] |
| 198 | + |
| 199 | + # Trim off unused layers to make forward pass more efficient |
| 200 | + self.features = features[0 : relus[blocks[-1]] + 1] |
| 201 | + |
| 202 | + def forward(self, inputs, targets): |
| 203 | + # model output to foreground probabilities |
| 204 | + inputs = nn.functional.softmax(inputs, dim=1) |
| 205 | + # we need to clone the tensor here before slicing otherwise pytorch |
| 206 | + # will lose track of information required for gradient computation |
| 207 | + inputs = inputs.clone()[:, 1, :, :] |
| 208 | + |
| 209 | + # masks are longs but vgg wants floats |
| 210 | + targets = targets.float() |
| 211 | + |
| 212 | + # normalize foreground pixels to ImageNet statistics for pre-trained VGG |
| 213 | + mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225] |
| 214 | + inputs = normalize(inputs, mean, std) |
| 215 | + targets = normalize(targets, mean, std) |
| 216 | + |
| 217 | + # N, H, W -> N, C, H, W |
| 218 | + inputs = inputs.unsqueeze(1) |
| 219 | + targets = targets.unsqueeze(1) |
| 220 | + |
| 221 | + # repeat channel three times for using a pre-trained three-channel VGG |
| 222 | + inputs = inputs.repeat(1, 3, 1, 1) |
| 223 | + targets = targets.repeat(1, 3, 1, 1) |
| 224 | + |
| 225 | + # extract feature maps and compare their weighted loss |
| 226 | + |
| 227 | + self.features(inputs) |
| 228 | + input_features = [hook.features.clone() for hook in self.hooks] |
| 229 | + |
| 230 | + self.features(targets) |
| 231 | + target_features = [hook.features for hook in self.hooks] |
| 232 | + |
| 233 | + loss = 0.0 |
| 234 | + |
| 235 | + for lhs, rhs, w in zip(input_features, target_features, self.weights): |
| 236 | + lhs = lhs.view(lhs.size(0), -1) |
| 237 | + rhs = rhs.view(rhs.size(0), -1) |
| 238 | + loss += nn.functional.l1_loss(lhs, rhs) * w |
| 239 | + |
| 240 | + return loss |
| 241 | + |
| 242 | + def close(self): |
| 243 | + for hook in self.hooks: |
| 244 | + hook.close() |
0 commit comments