@@ -76,6 +76,7 @@ def predict(
7676 self ,
7777 x : Union [torch .Tensor , np .ndarray , Image ],
7878 apply_boundary_weight : bool = False ,
79+ save_intermediate : bool = False ,
7980 ) -> Dict [str , torch .Tensor ]:
8081 """Run the input through the model.
8182
@@ -88,6 +89,10 @@ def predict(
8889 apply_boundary_weight (bool, default=False):
8990 Whether to apply boundary weights to mitigate boundary artefacts
9091 in aux predictions.
92+ save_intermediate (bool, default=False):
93+ Whether to save intermediate results (logits). If True, the method
94+ returns a tuple (final predictions, intermediate results), where the
95+ intermediate results are the raw model outputs before argmax.
9196
9297 Returns:
9398 Dict[str, torch.Tensor]:
@@ -117,15 +122,22 @@ def predict(
117122 .unsqueeze (0 )
118123 )
119124
125+ intermediate = None
120126 with torch .no_grad ():
121127 if self .mixed_precision :
122128 with torch .autocast (self .device .type , dtype = torch .float16 ):
123129 probs = self ._predict (x )
130+ if save_intermediate :
131+ intermediate = probs
124132 probs = self ._argmax (probs )
125133 else :
126134 probs = self ._predict (x )
135+ if save_intermediate :
136+ intermediate = probs
127137 probs = self ._argmax (probs )
128138
139+ if save_intermediate :
140+ return probs , intermediate
129141 return probs
130142
131143 def _to_tensor (self , x : Union [np .ndarray , Image ]) -> torch .Tensor :
@@ -265,6 +277,7 @@ def predict_sliding_win(
265277 stride : int ,
266278 padding : int = 20 ,
267279 apply_boundary_weight : bool = True ,
280+ save_intermediate : bool = False ,
268281 ) -> Dict [str , torch .Tensor ]:
269282 """Run the input through the model.
270283
@@ -283,11 +296,14 @@ def predict_sliding_win(
283296 apply_boundary_weight (bool, default=True):
284297 Whether to apply boundary weights to mitigate boundary artefacts
285298 in aux predictions.
299+ save_intermediate (bool, default=False):
300+ Whether to save intermediate results (logits). If True, the method
301+ returns a tuple (final predictions, intermediate results), where the
302+ intermediate results are the raw model outputs before argmax.
286303
287304 Returns:
288305 Dict[str, torch.Tensor]:
289- Dictionary containing the model predictions (probabilities).
290- Shapes: (B, C, H, W).
306+ Dictionary containing the model predictions. Shapes: (B, C, H, W).
291307 """
292308 # check if the input is a tensor
293309 if not isinstance (x , torch .Tensor ):
@@ -311,15 +327,23 @@ def predict_sliding_win(
311327 .unsqueeze (0 )
312328 )
313329
330+ intermediate = None
314331 with torch .no_grad ():
315332 if self .mixed_precision :
316333 with torch .autocast (self .device .type , dtype = torch .float16 ):
317334 probs = self ._predict_sliding_win (x , window_size , stride , padding )
335+ if save_intermediate :
336+ intermediate = probs
318337 probs = self ._argmax (probs )
319338 else :
320339 probs = self ._predict_sliding_win (x , window_size , stride , padding )
340+ if save_intermediate :
341+ intermediate = probs
321342 probs = self ._argmax (probs )
322343
344+ if save_intermediate :
345+ return probs , intermediate
346+
323347 return probs
324348
325349 @staticmethod
0 commit comments