Skip to content

Commit 424b91f

Browse files
committed
Implements topology-aware loss function; resolves #133
1 parent b7f6ebf commit 424b91f

File tree

3 files changed

+146
-1
lines changed

3 files changed

+146
-1
lines changed

robosat/hooks.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
class FeatureHook:
    """Captures the output of a module every time that module runs a forward pass.

    The hook stays registered until `close()` is called. It can also be used as
    a context manager, which guarantees the hook is removed even if the forward
    pass raises:

        with FeatureHook(module) as hook:
            model(batch)
            features = hook.features
    """

    def __init__(self, module):
        """Registers a forward hook on `module`.

        Args:
          module: the module whose forward outputs should be captured
        """

        # Most recent outputs of `module`; `None` until the first forward pass
        self.features = None
        self.hook = module.register_forward_hook(self.on)

    def on(self, module, inputs, outputs):
        """Forward hook callback: remembers the module's latest outputs."""

        self.features = outputs

    def close(self):
        """Removes the forward hook; `features` keeps the last captured outputs."""

        self.hook.remove()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Always unregister, but never swallow an exception from the with-body
        self.close()
        return False

robosat/losses.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@
44
import torch
55
import torch.nn as nn
66

7+
from torchvision.transforms.functional import normalize
8+
from torchvision.models import vgg16_bn
9+
10+
from robosat.hooks import FeatureHook
11+
712

813
class CrossEntropyLoss2d(nn.Module):
914
"""Cross-entropy.
@@ -117,3 +122,123 @@ def forward(self, inputs, targets):
117122
loss += torch.dot(nn.functional.relu(errors_sorted), iou)
118123

119124
return loss / N
125+
126+
127+
class CombinedLoss(nn.Module):
    """Weighted combination of losses.
    """

    def __init__(self, criteria, weights):
        """Creates a `CombinedLoss` instance.

        Args:
          criteria: list of criteria to combine; each is called as `criterion(inputs, targets)`
          weights: one-dimensional tensor with one weight per criterion
        """

        super().__init__()

        assert len(weights.size()) == 1, "weights must be a one-dimensional tensor"
        assert weights.size(0) == len(criteria), "need exactly one weight per criterion"

        # NOTE(review): both are stored as plain attributes, so calling `.to(device)`
        # on this module does not move the criteria or the weights - callers have to
        # place them on the right device themselves.
        self.criteria = criteria
        self.weights = weights

    def forward(self, inputs, targets):
        """Returns the weighted sum of all criteria evaluated on `(inputs, targets)`.

        Args:
          inputs: passed through unchanged to every criterion
          targets: passed through unchanged to every criterion

        Returns:
          The sum of `weight * criterion(inputs, targets)` over all criteria
          (the float `0.0` if there are no criteria).
        """

        loss = 0.0

        for criterion, w in zip(self.criteria, self.weights):
            loss += w * criterion(inputs, targets)

        return loss
156+
157+
158+
class TopologyLoss(nn.Module):
    """Topology loss working on a pre-trained model's feature map similarities.

    See:
    - https://arxiv.org/abs/1603.08155
    - https://arxiv.org/abs/1712.02190

    Note: implementation works with single channel tensors and stacks them for VGG.
    """

    # ImageNet per-channel statistics used to normalize inputs for the pre-trained VGG
    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self, blocks, weights):
        """Creates a `TopologyLoss` instance.

        Args:
          blocks: non-empty, ascending list of block indices to use, in `[0, 4]` (e.g. `[0, 1, 2]`)
          weights: tensor to tune losses per block (e.g. `[0.2, 0.6, 0.2]`)

        Note: the block indices correspond to a pre-trained VGG's feature maps to use.
        """

        super().__init__()

        assert len(weights.size()) == 1
        assert weights.size(0) == len(blocks)

        self.weights = weights

        assert 0 < len(blocks) <= 5
        assert all(i in range(5) for i in blocks)
        # compare against a list so that tuples of indices are accepted, too
        assert sorted(blocks) == list(blocks)

        features = vgg16_bn(pretrained=True).features
        features.eval()

        for param in features.parameters():
            param.requires_grad = False

        # a block's feature map is the output of the layer right in front of its max-pooling layer
        relus = [i - 1 for i, m in enumerate(features) if isinstance(m, nn.MaxPool2d)]

        self.hooks = [FeatureHook(features[relus[i]]) for i in blocks]

        # Trim off unused layers to make forward pass more efficient
        self.features = features[0 : relus[blocks[-1]] + 1]

    def train(self, mode=True):
        """Switches the training flag but always keeps the frozen VGG in eval mode.

        Without this, putting a parent module into training mode would also put
        the pre-trained feature extractor set to eval mode in `__init__` into
        training mode.
        """

        super().train(mode)
        self.features.eval()

        return self

    def _as_vgg_input(self, planes):
        """Turns single-channel `N, H, W` planes into normalized `N, 3, H, W` VGG inputs."""

        mean = planes.new_tensor(self.IMAGENET_MEAN).view(1, 3, 1, 1)
        std = planes.new_tensor(self.IMAGENET_STD).view(1, 3, 1, 1)

        # N, H, W -> N, 1, H, W; broadcasting against the per-channel statistics
        # then yields the plane repeated three times, each normalized for its channel.
        # Normalizing only after the channel axis exists matters: normalizing the
        # N, H, W tensor directly would treat the batch axis as the channel axis.
        return (planes.unsqueeze(1) - mean) / std

    def forward(self, inputs, targets):
        """Returns the weighted L1 distance between VGG feature maps of prediction and mask.

        Args:
          inputs: model output; softmax is taken over dim 1 and index 1 is used as foreground
          targets: masks; converted to float and given a channel axis at dim 1

        Note: shapes are presumably `N, C, H, W` logits and `N, H, W` masks - confirm against caller.
        """

        # model output to foreground probabilities
        probabilities = nn.functional.softmax(inputs, dim=1)[:, 1, :, :]

        # masks are longs but vgg wants floats
        masks = targets.float()

        # extract feature maps and compare their weighted loss

        self.features(self._as_vgg_input(probabilities))
        # every forward pass hands fresh output tensors to the hooks,
        # so these references are not overwritten by the pass below
        input_features = [hook.features for hook in self.hooks]

        # the masks are constants: no need to track gradients for their feature maps
        with torch.no_grad():
            self.features(self._as_vgg_input(masks))
        target_features = [hook.features for hook in self.hooks]

        loss = 0.0

        for lhs, rhs, w in zip(input_features, target_features, self.weights):
            lhs = lhs.reshape(lhs.size(0), -1)
            rhs = rhs.reshape(rhs.size(0), -1)
            loss += nn.functional.l1_loss(lhs, rhs) * w

        return loss

    def close(self):
        """Removes all forward hooks from the VGG feature extractor."""

        for hook in self.hooks:
            hook.close()

robosat/tools/train.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
)
2727
from robosat.datasets import SlippyMapTilesConcatenation
2828
from robosat.metrics import Metrics
29-
from robosat.losses import CrossEntropyLoss2d, mIoULoss2d, FocalLoss2d, LovaszLoss2d
29+
from robosat.losses import CrossEntropyLoss2d, mIoULoss2d, FocalLoss2d, LovaszLoss2d, CombinedLoss, TopologyLoss
3030
from robosat.unet import UNet
3131
from robosat.utils import plot
3232
from robosat.config import load_config
@@ -108,6 +108,14 @@ def map_location(storage, _):
108108
else:
109109
sys.exit("Error: Unknown [opt][loss] value !")
110110

111+
# use first three vgg feature maps and weight their contribution to loss
112+
topology_weights = torch.tensor([0.2, 0.6, 0.2]).to(device)
113+
topology_loss = TopologyLoss([0, 1, 2], topology_weights).to(device)
114+
115+
# combine the pixel-wise and the topology loss and weight them
116+
loss_weights = torch.tensor([1.0, 10.0]).to(device)
117+
criterion = CombinedLoss([criterion, topology_loss], loss_weights).to(device)
118+
111119
train_loader, val_loader = get_dataset_loaders(model, dataset, args.workers)
112120

113121
num_epochs = model["opt"]["epochs"]
@@ -163,6 +171,8 @@ def map_location(storage, _):
163171

164172
torch.save(states, os.path.join(model["common"]["checkpoint"], checkpoint))
165173

174+
topology_loss.close()
175+
166176

167177
def train(loader, num_classes, device, net, optimizer, criterion):
168178
num_samples = 0

0 commit comments

Comments
 (0)