Skip to content

Commit 08b2fc3

Browse files
Allow for full chain not to compute certain losses during training
1 parent f2770ba commit 08b2fc3

File tree

1 file changed

+19
-16
lines changed

1 file changed

+19
-16
lines changed

spine/model/full_chain.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1188,24 +1188,26 @@ def __init__(self, chain, uresnet_deghost=None, uresnet_deghost_loss=None,
11881188
process_chain_config(self, **chain)
11891189

11901190
# Initialize the deghosting loss
1191-
if self.deghosting == 'uresnet':
1191+
if self.deghosting == 'uresnet' and uresnet_deghost_loss is not None:
11921192
self.deghost_loss = SegmentationLoss(
11931193
uresnet_deghost, uresnet_deghost_loss)
11941194

11951195
# Initialize the segmentation/PPN losses
11961196
if self.segmentation == 'uresnet':
1197-
assert ((uresnet_loss is not None) ^
1198-
(uresnet_ppn_loss is not None)), (
1199-
"If the segmentation is using UResNet, must provide the "
1197+
assert not ((uresnet_loss is not None) and
1198+
(uresnet_ppn_loss is not None)), (
1199+
"If the segmentation is using UResNet, can provide either "
12001200
"`uresnet_loss` or `uresnet_ppn_loss` configuration block.")
12011201
if uresnet_loss is not None:
12021202
self.uresnet_loss = SegmentationLoss(uresnet, uresnet_loss)
1203-
else:
1203+
elif uresnet_ppn_loss is not None:
12041204
self.uresnet_ppn_loss = UResNetPPNLoss(
12051205
**uresnet_ppn, **uresnet_ppn_loss)
12061206

12071207
# Initialize the graph-SPICE loss
1208-
if self.fragmentation is not None and 'graph_spice' in self.fragmentation:
1208+
if (self.fragmentation is not None and
1209+
'graph_spice' in self.fragmentation and
1210+
graph_spice_loss is not None):
12091211
self.graph_spice_loss = GraphSPICELoss(graph_spice, graph_spice_loss)
12101212

12111213
# Initialize the GraPA lossses
@@ -1214,11 +1216,9 @@ def __init__(self, chain, uresnet_deghost=None, uresnet_deghost_loss=None,
12141216
'particle': grappa_particle_loss, 'inter': grappa_inter_loss
12151217
}
12161218
for stage, config in self.grappa_losses.items():
1217-
if getattr(self, f'{stage}_aggregation') == 'grappa':
1219+
if (getattr(self, f'{stage}_aggregation') == 'grappa' and
1220+
config is not None):
12181221
name = f'grappa_{stage}_loss'
1219-
assert config is not None, (
1220-
f"If the {stage} aggregation is done using GrapPA, "
1221-
f"must provide the {name} configuration block.")
12221222
setattr(self, name, GrapPALoss(config))
12231223

12241224
@property
@@ -1275,7 +1275,7 @@ def forward(self, seg_label=None, ppn_label=None, clust_label=None,
12751275
self.result = {'accuracy': 1., 'loss': 0., 'num_losses': 0}
12761276

12771277
# Apply the deghosting loss
1278-
if self.deghosting == 'uresnet':
1278+
if self.deghosting == 'uresnet' and hasattr(self, 'deghost_loss'):
12791279
# Convert segmentation labels to ghost labels
12801280
ghost_label_tensor = seg_label.tensor.clone()
12811281
ghost_label_tensor[:, SHAPE_COL] = (
@@ -1307,21 +1307,23 @@ def forward(self, seg_label=None, ppn_label=None, clust_label=None,
13071307
# reconstructed semantic segmentation of the image
13081308
clust_label = clust_label_adapt
13091309

1310-
# Store the loss dictionary
1310+
# Store the loss dictionary, if requested
13111311
if hasattr(self, 'uresnet_loss'):
13121312
res_seg = self.uresnet_loss(
13131313
seg_label=seg_label, segmentation=segmentation)
13141314
self.update_result(res_seg, 'uresnet')
13151315

1316-
else:
1316+
elif hasattr(self, 'uresnet_ppn_loss'):
13171317
res_seg = self.uresnet_ppn_loss(
13181318
seg_label=seg_label, ppn_label=ppn_label,
13191319
clust_label=clust_label, segmentation=segmentation,
13201320
**output)
13211321
self.update_result(res_seg)
13221322

13231323
# Apply the Graph-SPICE loss
1324-
if self.fragmentation is not None and 'graph_spice' in self.fragmentation:
1324+
if (self.fragmentation is not None and
1325+
'graph_spice' in self.fragmentation and
1326+
hasattr(self, 'graph_spice_loss')):
13251327
# Prepare Graph-SPICE loss input
13261328
loss_dict = {}
13271329
for key, value in output.items():
@@ -1336,9 +1338,10 @@ def forward(self, seg_label=None, ppn_label=None, clust_label=None,
13361338

13371339
# Apply the aggregation losses
13381340
for stage in self.grappa_losses.keys():
1339-
if getattr(self, f'{stage}_aggregation') == 'grappa':
1341+
name = f'grappa_{stage}_loss'
1342+
if (getattr(self, f'{stage}_aggregation') == 'grappa' and
1343+
hasattr(self, name)):
13401344
# Prepare the input to the loss function
1341-
name = f'grappa_{stage}_loss'
13421345
prefix = f'{stage}_fragment' if stage != 'inter' else 'particle'
13431346
loss_dict = {}
13441347
for k, v in output.items():

0 commit comments

Comments
 (0)