@@ -156,6 +156,15 @@ def _setup_sliding_window_inferer(self):
156156 if not hasattr (self .cfg , 'inference' ):
157157 return
158158
159+ # For 2D models with do_2d=True, disable sliding window inference
160+ if getattr (self .cfg .data , 'do_2d' , False ):
161+ warnings .warn (
162+ "Sliding-window inference disabled for 2D models with do_2d=True. "
163+ "Using direct inference instead." ,
164+ UserWarning ,
165+ )
166+ return
167+
159168 roi_size = self ._resolve_inferer_roi_size ()
160169 if roi_size is None :
161170 warnings .warn (
@@ -200,12 +209,20 @@ def _resolve_inferer_roi_size(self) -> Optional[Tuple[int, ...]]:
200209 if hasattr (self .cfg , 'model' ) and hasattr (self .cfg .model , 'output_size' ):
201210 output_size = getattr (self .cfg .model , 'output_size' , None )
202211 if output_size :
203- return tuple (int (v ) for v in output_size )
212+ roi_size = tuple (int (v ) for v in output_size )
213+ # For 2D models with do_2d=True, convert to 3D ROI size
214+ if getattr (self .cfg .data , 'do_2d' , False ) and len (roi_size ) == 2 :
215+ roi_size = (1 ,) + roi_size # Add depth dimension
216+ return roi_size
204217
205218 if hasattr (self .cfg , 'data' ) and hasattr (self .cfg .data , 'patch_size' ):
206219 patch_size = getattr (self .cfg .data , 'patch_size' , None )
207220 if patch_size :
208- return tuple (int (v ) for v in patch_size )
221+ roi_size = tuple (int (v ) for v in patch_size )
222+ # For 2D models with do_2d=True, convert to 3D ROI size
223+ if getattr (self .cfg .data , 'do_2d' , False ) and len (roi_size ) == 2 :
224+ roi_size = (1 ,) + roi_size # Add depth dimension
225+ return roi_size
209226
210227 return None
211228
@@ -367,6 +384,10 @@ def _predict_with_tta(self, images: torch.Tensor, mask: Optional[torch.Tensor] =
367384 f"Expected shapes: (D, H, W), (B, D, H, W), or (B, C, D, H, W)"
368385 )
369386
387+ # For 2D models with do_2d=True, squeeze the depth dimension if present
388+ if getattr (self .cfg .data , 'do_2d' , False ) and images .size (2 ) == 1 : # [B, C, 1, H, W] -> [B, C, H, W]
389+ images = images .squeeze (2 )
390+
370391 # Get TTA configuration (default to no augmentation if not configured)
371392 if hasattr (self .cfg , 'inference' ) and hasattr (self .cfg .inference , 'test_time_augmentation' ):
372393 tta_flip_axes_config = getattr (self .cfg .inference .test_time_augmentation , 'flip_axes' , None )
@@ -385,23 +406,21 @@ def _predict_with_tta(self, images: torch.Tensor, mask: Optional[torch.Tensor] =
385406 ensemble_result = self ._apply_tta_preprocessing (pred )
386407 else :
387408 if tta_flip_axes_config == 'all' or tta_flip_axes_config == []:
388- # "all" or []: All 8 flips (all combinations of Z, Y, X)
389- # IMPORTANT: MONAI Flip spatial_axis behavior for (B, C, D, H, W) tensors:
390- # spatial_axis=[0] flips C (channel) - WRONG for TTA!
391- # spatial_axis=[1] flips D (depth/Z) - CORRECT
392- # spatial_axis=[2] flips H (height/Y) - CORRECT
393- # spatial_axis=[3] flips W (width/X) - CORRECT
394- # Must use [1, 2, 3] for [D, H, W] flips, NOT [0, 1, 2]!
395- tta_flip_axes = [
396- [], # No flip
397- [1 ], # Flip Z (depth)
398- [2 ], # Flip Y (height)
399- [3 ], # Flip X (width)
400- [1 , 2 ], # Flip Z+Y
401- [1 , 3 ], # Flip Z+X
402- [2 , 3 ], # Flip Y+X
403- [1 , 2 , 3 ], # Flip Z+Y+X
404- ]
409+ # "all" or []: All flips (all combinations of spatial axes)
410+ # Determine spatial axes based on data dimensions
411+ if images .dim () == 5 : # 3D data: [B, C, D, H, W]
412+ spatial_axes = [1 , 2 , 3 ] # [D, H, W]
413+ elif images .dim () == 4 : # 2D data: [B, C, H, W]
414+ spatial_axes = [1 , 2 ] # [H, W]
415+ else :
416+ raise ValueError (f"Unsupported data dimensions: { images .dim ()} " )
417+
418+ # Generate all combinations of spatial axes
419+ tta_flip_axes = [[]] # No flip baseline
420+ for r in range (1 , len (spatial_axes ) + 1 ):
421+ from itertools import combinations
422+ for combo in combinations (spatial_axes , r ):
423+ tta_flip_axes .append (list (combo ))
405424 elif isinstance (tta_flip_axes_config , (list , tuple )):
406425 # Custom list: Add no-flip baseline + user-specified flips
407426 tta_flip_axes = [[]] + list (tta_flip_axes_config )
0 commit comments