@@ -1188,24 +1188,26 @@ def __init__(self, chain, uresnet_deghost=None, uresnet_deghost_loss=None,
11881188 process_chain_config (self , ** chain )
11891189
11901190 # Initialize the deghosting loss
1191- if self .deghosting == 'uresnet' :
1191+ if self .deghosting == 'uresnet' and uresnet_deghost_loss is not None :
11921192 self .deghost_loss = SegmentationLoss (
11931193 uresnet_deghost , uresnet_deghost_loss )
11941194
11951195 # Initialize the segmentation/PPN losses
11961196 if self .segmentation == 'uresnet' :
1197- assert ((uresnet_loss is not None ) ^
1198- (uresnet_ppn_loss is not None )), (
1199- "If the segmentation is using UResNet, must provide the "
1197+ assert not ((uresnet_loss is not None ) and
1198+ (uresnet_ppn_loss is not None )), (
1199+ "If the segmentation is using UResNet, can provide either "
12001200 "`uresnet_loss` or `uresnet_ppn_loss` configuration block." )
12011201 if uresnet_loss is not None :
12021202 self .uresnet_loss = SegmentationLoss (uresnet , uresnet_loss )
1203- else :
1203+ elif uresnet_ppn_loss is not None :
12041204 self .uresnet_ppn_loss = UResNetPPNLoss (
12051205 ** uresnet_ppn , ** uresnet_ppn_loss )
12061206
12071207 # Initialize the graph-SPICE loss
1208- if self .fragmentation is not None and 'graph_spice' in self .fragmentation :
1208+ if (self .fragmentation is not None and
1209+ 'graph_spice' in self .fragmentation and
1210+ graph_spice_loss is not None ):
12091211 self .graph_spice_loss = GraphSPICELoss (graph_spice , graph_spice_loss )
12101212
12111213 # Initialize the GraPA lossses
@@ -1214,11 +1216,9 @@ def __init__(self, chain, uresnet_deghost=None, uresnet_deghost_loss=None,
12141216 'particle' : grappa_particle_loss , 'inter' : grappa_inter_loss
12151217 }
12161218 for stage , config in self .grappa_losses .items ():
1217- if getattr (self , f'{ stage } _aggregation' ) == 'grappa' :
1219+ if (getattr (self , f'{ stage } _aggregation' ) == 'grappa' and
1220+ config is not None ):
12181221 name = f'grappa_{ stage } _loss'
1219- assert config is not None , (
1220- f"If the { stage } aggregation is done using GrapPA, "
1221- f"must provide the { name } configuration block." )
12221222 setattr (self , name , GrapPALoss (config ))
12231223
12241224 @property
@@ -1275,7 +1275,7 @@ def forward(self, seg_label=None, ppn_label=None, clust_label=None,
12751275 self .result = {'accuracy' : 1. , 'loss' : 0. , 'num_losses' : 0 }
12761276
12771277 # Apply the deghosting loss
1278- if self .deghosting == 'uresnet' :
1278+ if self .deghosting == 'uresnet' and hasattr ( self , 'deghost_loss' ) :
12791279 # Convert segmentation labels to ghost labels
12801280 ghost_label_tensor = seg_label .tensor .clone ()
12811281 ghost_label_tensor [:, SHAPE_COL ] = (
@@ -1307,21 +1307,23 @@ def forward(self, seg_label=None, ppn_label=None, clust_label=None,
13071307 # reconstructed semantic segmentation of the image
13081308 clust_label = clust_label_adapt
13091309
1310- # Store the loss dictionary
1310+ # Store the loss dictionary, if requested
13111311 if hasattr (self , 'uresnet_loss' ):
13121312 res_seg = self .uresnet_loss (
13131313 seg_label = seg_label , segmentation = segmentation )
13141314 self .update_result (res_seg , 'uresnet' )
13151315
1316- else :
1316+ elif hasattr ( self , 'uresnet_ppn_loss' ) :
13171317 res_seg = self .uresnet_ppn_loss (
13181318 seg_label = seg_label , ppn_label = ppn_label ,
13191319 clust_label = clust_label , segmentation = segmentation ,
13201320 ** output )
13211321 self .update_result (res_seg )
13221322
13231323 # Apply the Graph-SPICE loss
1324- if self .fragmentation is not None and 'graph_spice' in self .fragmentation :
1324+ if (self .fragmentation is not None and
1325+ 'graph_spice' in self .fragmentation and
1326+ hasattr (self , 'graph_spice_loss' )):
13251327 # Prepare Graph-SPICE loss input
13261328 loss_dict = {}
13271329 for key , value in output .items ():
@@ -1336,9 +1338,10 @@ def forward(self, seg_label=None, ppn_label=None, clust_label=None,
13361338
13371339 # Apply the aggregation losses
13381340 for stage in self .grappa_losses .keys ():
1339- if getattr (self , f'{ stage } _aggregation' ) == 'grappa' :
1341+ name = f'grappa_{ stage } _loss'
1342+ if (getattr (self , f'{ stage } _aggregation' ) == 'grappa' and
1343+ hasattr (self , name )):
13401344 # Prepare the input to the loss function
1341- name = f'grappa_{ stage } _loss'
13421345 prefix = f'{ stage } _fragment' if stage != 'inter' else 'particle'
13431346 loss_dict = {}
13441347 for k , v in output .items ():
0 commit comments