@@ -1005,33 +1005,72 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> STEP_
         ds_weights = [1.0] + [0.5 ** i for i in range(1, len(ds_outputs) + 1)]
         all_outputs = [main_output] + ds_outputs
 
+        # Check if multi-task learning is configured
+        is_multi_task = hasattr(self.cfg.model, 'multi_task_config') and self.cfg.model.multi_task_config is not None
+
         for scale_idx, (output, ds_weight) in enumerate(zip(all_outputs, ds_weights)):
             # Match target to output size
             target = self._match_target_to_output(labels, output)
 
             # Compute loss for this scale
             scale_loss = 0.0
-            for loss_fn, weight in zip(self.loss_functions, self.loss_weights):
-                loss = loss_fn(output, target)
 
-                # Check for NaN/Inf immediately after computing loss
-                if self.enable_nan_detection and (torch.isnan(loss) or torch.isinf(loss)):
-                    print(f"\n{'=' * 80}")
-                    print(f"⚠️ NaN/Inf detected in loss computation!")
-                    print(f"{'=' * 80}")
-                    print(f"Loss function: {loss_fn.__class__.__name__}")
-                    print(f"Loss value: {loss.item()}")
-                    print(f"Scale: {scale_idx}, Weight: {weight}")
-                    print(f"Output shape: {output.shape}, range: [{output.min():.4f}, {output.max():.4f}]")
-                    print(f"Target shape: {target.shape}, range: [{target.min():.4f}, {target.max():.4f}]")
-                    print(f"Output contains NaN: {torch.isnan(output).any()}")
-                    print(f"Target contains NaN: {torch.isnan(target).any()}")
-                    if self.debug_on_nan:
-                        print(f"\nEntering debugger...")
-                        pdb.set_trace()
-                    raise ValueError(f"NaN/Inf in loss at scale {scale_idx}")
-
-                scale_loss += loss * weight
+            if is_multi_task:
+                # Multi-task learning with deep supervision:
+                # Apply specific losses to specific channels at each scale
+                for task_idx, task_config in enumerate(self.cfg.model.multi_task_config):
+                    start_ch, end_ch, task_name, loss_indices = task_config
+
+                    # Extract channels for this task
+                    task_output = output[:, start_ch:end_ch, ...]
+                    task_target = target[:, start_ch:end_ch, ...]
+
+                    # Apply specified losses for this task
+                    for loss_idx in loss_indices:
+                        loss_fn = self.loss_functions[loss_idx]
+                        weight = self.loss_weights[loss_idx]
+
+                        loss = loss_fn(task_output, task_target)
+
+                        # Check for NaN/Inf
+                        if self.enable_nan_detection and (torch.isnan(loss) or torch.isinf(loss)):
+                            print(f"\n{'=' * 80}")
+                            print(f"⚠️ NaN/Inf detected in deep supervision multi-task loss!")
+                            print(f"{'=' * 80}")
+                            print(f"Scale: {scale_idx}, Task: {task_name} (channels {start_ch}:{end_ch})")
+                            print(f"Loss function: {loss_fn.__class__.__name__} (index {loss_idx})")
+                            print(f"Loss value: {loss.item()}")
+                            print(f"Output shape: {task_output.shape}, range: [{task_output.min():.4f}, {task_output.max():.4f}]")
+                            print(f"Target shape: {task_target.shape}, range: [{task_target.min():.4f}, {task_target.max():.4f}]")
+                            if self.debug_on_nan:
+                                print(f"\nEntering debugger...")
+                                pdb.set_trace()
+                            raise ValueError(f"NaN/Inf in deep supervision loss at scale {scale_idx}, task {task_name}")
+
+                        scale_loss += loss * weight
+            else:
+                # Standard deep supervision: apply all losses to all outputs
+                for loss_fn, weight in zip(self.loss_functions, self.loss_weights):
+                    loss = loss_fn(output, target)
+
+                    # Check for NaN/Inf immediately after computing loss
+                    if self.enable_nan_detection and (torch.isnan(loss) or torch.isinf(loss)):
+                        print(f"\n{'=' * 80}")
+                        print(f"⚠️ NaN/Inf detected in loss computation!")
+                        print(f"{'=' * 80}")
+                        print(f"Loss function: {loss_fn.__class__.__name__}")
+                        print(f"Loss value: {loss.item()}")
+                        print(f"Scale: {scale_idx}, Weight: {weight}")
+                        print(f"Output shape: {output.shape}, range: [{output.min():.4f}, {output.max():.4f}]")
+                        print(f"Target shape: {target.shape}, range: [{target.min():.4f}, {target.max():.4f}]")
+                        print(f"Output contains NaN: {torch.isnan(output).any()}")
+                        print(f"Target contains NaN: {torch.isnan(target).any()}")
+                        if self.debug_on_nan:
+                            print(f"\nEntering debugger...")
+                            pdb.set_trace()
+                        raise ValueError(f"NaN/Inf in loss at scale {scale_idx}")
+
+                    scale_loss += loss * weight
 
         total_loss += scale_loss * ds_weight
         loss_dict[f'train_loss_scale_{scale_idx}'] = scale_loss.item()
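
Each `multi_task_config` entry unpacks to `(start_ch, end_ch, task_name, loss_indices)`: a half-open channel range, a label used in logging, and indices into the shared `loss_functions` / `loss_weights` lists. The sketch below shows the routing in isolation; the concrete config values, losses, and tensor shapes are illustrative assumptions, not taken from this commit.

```python
import torch
import torch.nn as nn

# Stand-ins for self.loss_functions / self.loss_weights (assumed, illustrative).
loss_functions = [nn.MSELoss(), nn.BCEWithLogitsLoss(), nn.L1Loss()]
loss_weights = [1.0, 1.0, 0.5]

# Each entry: (start_ch, end_ch, task_name, loss_indices).
multi_task_config = [
    (0, 1, "sdt", [0]),     # channel 0: tanh-normalized SDT -> MSE
    (1, 3, "seg", [1, 2]),  # channels 1-2: segmentation -> BCE-with-logits + L1
]

output = torch.randn(2, 3, 8, 16, 16)  # (B, C, D, H, W) prediction at one scale
target = torch.rand(2, 3, 8, 16, 16)   # matched target at the same scale

scale_loss = torch.tensor(0.0)
for start_ch, end_ch, task_name, loss_indices in multi_task_config:
    task_output = output[:, start_ch:end_ch, ...]  # slice this task's channels
    task_target = target[:, start_ch:end_ch, ...]
    for loss_idx in loss_indices:
        loss = loss_functions[loss_idx](task_output, task_target)
        scale_loss = scale_loss + loss * loss_weights[loss_idx]

print(scale_loss)  # one scalar per scale; training then weights it by ds_weight
```

Indexing into one shared loss list keeps the per-task losses aligned with the flat `loss_weights` config instead of constructing separate loss objects per task.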
@@ -1100,15 +1139,38 @@ def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> STE
         ds_weights = [1.0] + [0.5 ** i for i in range(1, len(ds_outputs) + 1)]
         all_outputs = [main_output] + ds_outputs
 
+        # Check if multi-task learning is configured
+        is_multi_task = hasattr(self.cfg.model, 'multi_task_config') and self.cfg.model.multi_task_config is not None
+
         for scale_idx, (output, ds_weight) in enumerate(zip(all_outputs, ds_weights)):
             # Match target to output size
             target = self._match_target_to_output(labels, output)
 
             # Compute loss for this scale
             scale_loss = 0.0
-            for loss_fn, weight in zip(self.loss_functions, self.loss_weights):
-                loss = loss_fn(output, target)
-                scale_loss += loss * weight
+
+            if is_multi_task:
+                # Multi-task learning with deep supervision:
+                # Apply specific losses to specific channels at each scale
+                for task_idx, task_config in enumerate(self.cfg.model.multi_task_config):
+                    start_ch, end_ch, task_name, loss_indices = task_config
+
+                    # Extract channels for this task
+                    task_output = output[:, start_ch:end_ch, ...]
+                    task_target = target[:, start_ch:end_ch, ...]
+
+                    # Apply specified losses for this task
+                    for loss_idx in loss_indices:
+                        loss_fn = self.loss_functions[loss_idx]
+                        weight = self.loss_weights[loss_idx]
+
+                        loss = loss_fn(task_output, task_target)
+                        scale_loss += loss * weight
+            else:
+                # Standard deep supervision: apply all losses to all outputs
+                for loss_fn, weight in zip(self.loss_functions, self.loss_weights):
+                    loss = loss_fn(output, target)
+                    scale_loss += loss * weight
 
         total_loss += scale_loss * ds_weight
         loss_dict[f'val_loss_scale_{scale_idx}'] = scale_loss.item()
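
Both steps weight the auxiliary outputs with the same geometric decay: the main output keeps full weight and each deeper supervision head counts half as much as the previous one. A quick sanity check of what `ds_weights` evaluates to, assuming two auxiliary outputs (the placeholder names are illustrative):

```python
# With a main output plus two deep-supervision outputs, the scale weights decay
# geometrically: [1.0, 0.5, 0.25].
ds_outputs = ["aux_1", "aux_2"]  # placeholders for two auxiliary heads
ds_weights = [1.0] + [0.5 ** i for i in range(1, len(ds_outputs) + 1)]
assert ds_weights == [1.0, 0.5, 0.25]
# total_loss = sum of scale_loss * ds_weight over all scales
```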
@@ -1367,6 +1429,10 @@ def _match_target_to_output(
         For segmentation masks, uses nearest-neighbor interpolation to preserve labels.
         For continuous targets, uses trilinear interpolation.
 
+        IMPORTANT: For continuous targets in range [-1, 1] (e.g., tanh-normalized SDT),
+        trilinear interpolation can cause overshooting beyond bounds. We clamp the
+        resized targets back to [-1, 1] to prevent loss explosion.
+
         Args:
             target: Target tensor of shape (B, C, D, H, W)
             output: Output tensor of shape (B, C, D', H', W')
@@ -1396,6 +1462,18 @@ def _match_target_to_output(
                 align_corners=False,
             )
 
+            # CRITICAL FIX: Clamp resized targets to prevent interpolation overshooting.
+            # For targets in range [-1, 1] (e.g., tanh-normalized SDT), interpolation can
+            # produce values outside this range (e.g., -1.2, 1.3), which causes loss explosion
+            # when used with tanh-activated predictions. The tighter [0, 1] test must come
+            # first: any [0, 1] target would also pass the looser [-1.5, 1.5] test below.
+            if target.min() >= 0.0 and target.max() <= 1.5:
+                # Likely normalized to [0, 1]
+                target_resized = torch.clamp(target_resized, 0.0, 1.0)
+            elif target.min() >= -1.5 and target.max() <= 1.5:
+                # Likely normalized to [-1, 1] (with some tolerance for existing overshoots)
+                target_resized = torch.clamp(target_resized, -1.0, 1.0)
+
         return target_resized
 
     def configure_optimizers(self) -> Dict[str, Any]:
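
A self-contained check of the same clamp heuristic; the helper name `clamp_resized_target` and the tensor shapes are assumptions for illustration, not part of the module. It mirrors the branch order above: the tighter [0, 1] test runs before the [-1, 1] test, because any [0, 1] target also falls inside the [-1.5, 1.5] bounds.

```python
import torch
import torch.nn.functional as F

def clamp_resized_target(target: torch.Tensor, resized: torch.Tensor) -> torch.Tensor:
    # Branch order matters: a [0, 1] target also satisfies the [-1.5, 1.5]
    # bounds, so the tighter [0, 1] range must be tested first.
    if target.min() >= 0.0 and target.max() <= 1.5:
        return torch.clamp(resized, 0.0, 1.0)
    if target.min() >= -1.5 and target.max() <= 1.5:
        return torch.clamp(resized, -1.0, 1.0)
    return resized

target = torch.rand(1, 1, 8, 16, 16) * 2.0 - 1.0  # values in [-1, 1]
resized = F.interpolate(target, size=(4, 8, 8), mode="trilinear", align_corners=False)
clamped = clamp_resized_target(target, resized)
assert clamped.min() >= -1.0 and clamped.max() <= 1.0
```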