Skip to content

Commit f994c2e

Browse files
committed
feat: add save_intermediate param to predictor to return output logits
1 parent e13d78e commit f994c2e

File tree

2 files changed

+27
-6
lines changed

2 files changed

+27
-6
lines changed

cellseg_models_pytorch/inference/predictor.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def predict(
7676
self,
7777
x: Union[torch.Tensor, np.ndarray, Image],
7878
apply_boundary_weight: bool = False,
79+
save_intermediate: bool = False,
7980
) -> Dict[str, torch.Tensor]:
8081
"""Run the input through the model.
8182
@@ -88,6 +89,10 @@ def predict(
8889
apply_boundary_weight (bool, default=True):
8990
Whether to apply boundary weights to mitigate boundary artefacts
9091
in aux predictions.
92+
save_intermediate (bool, default=False):
93+
Whether to save intermediate results (logits). If True, the method
94+
returns a tuple (final predictions, intermediate results), where the
95+
intermediate results are the raw model outputs before argmax.
9196
9297
Returns:
9398
Dict[str, torch.Tensor]:
@@ -117,15 +122,22 @@ def predict(
117122
.unsqueeze(0)
118123
)
119124

125+
intermediate = None
120126
with torch.no_grad():
121127
if self.mixed_precision:
122128
with torch.autocast(self.device.type, dtype=torch.float16):
123129
probs = self._predict(x)
130+
if save_intermediate:
131+
intermediate = probs
124132
probs = self._argmax(probs)
125133
else:
126134
probs = self._predict(x)
135+
if save_intermediate:
136+
intermediate = probs
127137
probs = self._argmax(probs)
128138

139+
if save_intermediate:
140+
return probs, intermediate
129141
return probs
130142

131143
def _to_tensor(self, x: Union[np.ndarray, Image]) -> torch.Tensor:
@@ -265,6 +277,7 @@ def predict_sliding_win(
265277
stride: int,
266278
padding: int = 20,
267279
apply_boundary_weight: bool = True,
280+
save_intermediate: bool = False,
268281
) -> Dict[str, torch.Tensor]:
269282
"""Run the input through the model.
270283
@@ -283,11 +296,14 @@ def predict_sliding_win(
283296
apply_boundary_weight (bool, default=True):
284297
Whether to apply boundary weights to mitigate boundary artefacts
285298
in aux predictions.
299+
save_intermediate (bool, default=False):
300+
Whether to save intermediate results (logits). If True, the method
301+
returns a tuple (final predictions, intermediate results), where the
302+
intermediate results are the raw model outputs before argmax.
286303
287304
Returns:
288305
Dict[str, torch.Tensor]:
289-
Dictionary containing the model predictions (probabilities).
290-
Shapes: (B, C, H, W).
306+
Dictionary containing the model predictions. Shapes: (B, C, H, W).
291307
"""
292308
# check if the input is a tensor
293309
if not isinstance(x, torch.Tensor):
@@ -311,15 +327,23 @@ def predict_sliding_win(
311327
.unsqueeze(0)
312328
)
313329

330+
intermediate = None
314331
with torch.no_grad():
315332
if self.mixed_precision:
316333
with torch.autocast(self.device.type, dtype=torch.float16):
317334
probs = self._predict_sliding_win(x, window_size, stride, padding)
335+
if save_intermediate:
336+
intermediate = probs
318337
probs = self._argmax(probs)
319338
else:
320339
probs = self._predict_sliding_win(x, window_size, stride, padding)
340+
if save_intermediate:
341+
intermediate = probs
321342
probs = self._argmax(probs)
322343

344+
if save_intermediate:
345+
return probs, intermediate
346+
323347
return probs
324348

325349
@staticmethod

cellseg_models_pytorch/wsi/cucim_reader.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,7 @@
1313
HAS_CUCIM = False
1414

1515

16-
CUCIM_READABLE_FORMATS = (
17-
".svs",
18-
".tiff",
19-
)
16+
CUCIM_READABLE_FORMATS = (".svs", ".tiff", ".tif")
2017

2118

2219
class CucimReader(SlideReaderBackend):

0 commit comments

Comments
 (0)