Skip to content

Commit 7e973c1

Browse files
committed
fix: fix bug in result segmentation smoothing after vectorizing
1 parent 6ccba55 commit 7e973c1

File tree

1 file changed

+78
-13
lines changed

1 file changed

+78
-13
lines changed

cellseg_models_pytorch/inference/post_processor.py

Lines changed: 78 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import gc
22
from functools import partial
33
from pathlib import Path
4-
from typing import Dict, List, Tuple, Union
4+
from typing import Callable, Dict, List, Tuple, Union
55

66
import numpy as np
77
import torch
@@ -52,13 +52,16 @@ def postproc_tissuemap(
5252
save_path: Union[Path, str] = None,
5353
coords: Tuple[int, int, int, int] = None,
5454
class_dict: Dict[str, int] = None,
55+
smooth_func: Callable = None,
5556
) -> np.ndarray:
5657
"""Run tissue map post-processing."""
5758
tissue_map = remove_debris_semantic(tissue_map, min_size=5000)
5859
tissue_map = fill_holes_semantic(tissue_map, min_size=5000).astype("i4")
5960

6061
if save_path is not None:
61-
self._save_sem2vector(save_path, tissue_map, coords, class_dict)
62+
self._save_sem2vector(
63+
save_path, tissue_map, coords, class_dict, smooth_func=smooth_func
64+
)
6265
gc.collect()
6366
else:
6467
gc.collect()
@@ -72,13 +75,21 @@ def postproc_inst(
7275
save_path: Union[Path, str] = None,
7376
coords: Tuple[int, int, int, int] = None,
7477
class_dict: Dict[str, int] = None,
78+
smooth_func: Callable = gaussian_smooth,
7579
) -> Tuple[np.ndarray, np.ndarray]:
7680
"""Run instace map post-processing."""
7781
inst_map = self.postproc_func(inst_map, aux_map).astype("i4")
7882
type_map = majority_vote_sequential(type_map, inst_map).astype("i4")
7983

8084
if save_path is not None:
81-
self._save_inst2vector(save_path, inst_map, type_map, coords, class_dict)
85+
self._save_inst2vector(
86+
save_path,
87+
inst_map,
88+
type_map,
89+
coords,
90+
class_dict,
91+
smooth_func=smooth_func,
92+
)
8293
gc.collect()
8394
else:
8495
gc.collect()
@@ -97,6 +108,9 @@ def postproc_parallel(
97108
class_dict_nuc: Dict[int, str] = None,
98109
class_dict_cyto: Dict[int, str] = None,
99110
class_dict_tissue: Dict[int, str] = None,
111+
nuc_smooth_func: Callable = gaussian_smooth,
112+
cyto_smooth_func: Callable = gaussian_smooth,
113+
tissue_smooth_func: Callable = None,
100114
) -> Dict[str, List[np.ndarray]]:
101115
"""Post-process the masks in parallel using multiprocessing."""
102116
# set up input args for
@@ -135,7 +149,11 @@ def postproc_parallel(
135149
if soft_masks["nuc"] is not None:
136150
nuc_results = self._pool_map(
137151
pool,
138-
partial(self.postproc_inst, class_dict=class_dict_nuc),
152+
partial(
153+
self.postproc_inst,
154+
class_dict=class_dict_nuc,
155+
smooth_func=nuc_smooth_func,
156+
),
139157
list(
140158
zip(
141159
nuc_inst_maps,
@@ -151,7 +169,11 @@ def postproc_parallel(
151169
if soft_masks["cyto"] is not None:
152170
cyto_results = self._pool_map(
153171
pool,
154-
partial(self.postproc_inst, class_dict=class_dict_cyto),
172+
partial(
173+
self.postproc_inst,
174+
class_dict=class_dict_cyto,
175+
smooth_func=cyto_smooth_func,
176+
),
155177
list(
156178
zip(
157179
cyto_inst_maps,
@@ -167,7 +189,11 @@ def postproc_parallel(
167189
if soft_masks["tissue"] is not None:
168190
tissue_results = self._pool_map(
169191
pool,
170-
partial(self.postproc_tissuemap, class_dict=class_dict_tissue),
192+
partial(
193+
self.postproc_tissuemap,
194+
class_dict=class_dict_tissue,
195+
smooth_func=tissue_smooth_func,
196+
),
171197
list(zip(tissue_maps, save_paths_tissue, coords)),
172198
progress_bar=progress_bar,
173199
)
@@ -186,6 +212,9 @@ def postproc_parallel_async(
186212
class_dict_nuc: Dict[int, str] = None,
187213
class_dict_cyto: Dict[int, str] = None,
188214
class_dict_tissue: Dict[int, str] = None,
215+
nuc_smooth_func: Callable = gaussian_smooth,
216+
cyto_smooth_func: Callable = gaussian_smooth,
217+
tissue_smooth_func: Callable = None,
189218
) -> Dict[str, List[np.ndarray]]:
190219
"""Post-process the masks in parallel using async."""
191220
# set up input args for
@@ -225,7 +254,11 @@ def postproc_parallel_async(
225254
if soft_masks["nuc"] is not None:
226255
nuc_results = self._pool_apply_async(
227256
pool,
228-
partial(self.postproc_inst, class_dict=class_dict_nuc),
257+
partial(
258+
self.postproc_inst,
259+
class_dict=class_dict_nuc,
260+
smooth_func=nuc_smooth_func,
261+
),
229262
list(
230263
zip(
231264
nuc_inst_maps,
@@ -240,7 +273,11 @@ def postproc_parallel_async(
240273
if soft_masks["cyto"] is not None:
241274
cyto_results = self._pool_apply_async(
242275
pool,
243-
partial(self.postproc_inst, class_dict=class_dict_cyto),
276+
partial(
277+
self.postproc_inst,
278+
class_dict=class_dict_cyto,
279+
smooth_func=cyto_smooth_func,
280+
),
244281
list(
245282
zip(
246283
cyto_inst_maps,
@@ -255,7 +292,11 @@ def postproc_parallel_async(
255292
if soft_masks["tissue"] is not None:
256293
tissue_results = self._pool_apply_async(
257294
pool,
258-
partial(self.postproc_tissuemap, class_dict=class_dict_tissue),
295+
partial(
296+
self.postproc_tissuemap,
297+
class_dict=class_dict_tissue,
298+
smooth_func=tissue_smooth_func,
299+
),
259300
list(zip(tissue_maps, save_paths_tissue, coords)),
260301
)
261302

@@ -276,6 +317,9 @@ def postproc_serial(
276317
class_dict_nuc: Dict[int, str] = None,
277318
class_dict_cyto: Dict[int, str] = None,
278319
class_dict_tissue: Dict[int, str] = None,
320+
nuc_smooth_func: Callable = gaussian_smooth,
321+
cyto_smooth_func: Callable = gaussian_smooth,
322+
tissue_smooth_func: Callable = None,
279323
) -> Dict[str, List[np.ndarray]]:
280324
"""Run post-processing sequentially."""
281325
nuc_inst_maps, nuc_aux_maps, nuc_type_maps = self._prepare_inst_maps(
@@ -306,7 +350,13 @@ def postproc_serial(
306350
if soft_masks["nuc"] is not None:
307351
nuc_results = [
308352
self.postproc_inst(
309-
inst_map, aux_map, type_map, save_path, coord, class_dict_nuc
353+
inst_map,
354+
aux_map,
355+
type_map,
356+
save_path,
357+
coord,
358+
class_dict_nuc,
359+
smooth_func=nuc_smooth_func,
310360
)
311361
for inst_map, aux_map, type_map, save_path, coord in zip(
312362
nuc_inst_maps,
@@ -320,7 +370,13 @@ def postproc_serial(
320370
if soft_masks["cyto"] is not None:
321371
cyto_results = [
322372
self.postproc_inst(
323-
inst_map, aux_map, type_map, save_path, coord, class_dict_cyto
373+
inst_map,
374+
aux_map,
375+
type_map,
376+
save_path,
377+
coord,
378+
class_dict_cyto,
379+
smooth_func=cyto_smooth_func,
324380
)
325381
for inst_map, aux_map, type_map, save_path, coord in zip(
326382
cyto_inst_maps,
@@ -333,7 +389,13 @@ def postproc_serial(
333389

334390
if soft_masks["tissue"] is not None:
335391
tissue_results = [
336-
self.postproc_tissuemap(tissue_map, save_path, coord, class_dict_tissue)
392+
self.postproc_tissuemap(
393+
tissue_map,
394+
save_path,
395+
coord,
396+
class_dict_tissue,
397+
smooth_func=tissue_smooth_func,
398+
)
337399
for tissue_map, save_path, coord in zip(
338400
tissue_maps, save_paths_tissue, coords
339401
)
@@ -398,6 +460,7 @@ def _save_inst2vector(
398460
class_dict: dict = None,
399461
compute_centroids: bool = False,
400462
compute_bboxes: bool = False,
463+
smooth_func: Callable = gaussian_smooth,
401464
) -> None:
402465
save_path = Path(save_path)
403466

@@ -410,7 +473,7 @@ def _save_inst2vector(
410473
xoff=xoff,
411474
yoff=yoff,
412475
class_dict=class_dict,
413-
smooth_func=gaussian_smooth,
476+
smooth_func=smooth_func,
414477
)
415478

416479
if compute_centroids:
@@ -426,6 +489,7 @@ def _save_sem2vector(
426489
sem_map: np.ndarray,
427490
coords: List[Tuple[int, int, int, int]] = None,
428491
class_dict: dict = None,
492+
smooth_func: Callable = None,
429493
) -> None:
430494
save_path = Path(save_path)
431495

@@ -437,6 +501,7 @@ def _save_sem2vector(
437501
xoff=xoff,
438502
yoff=yoff,
439503
class_dict=class_dict,
504+
smooth_func=smooth_func,
440505
)
441506

442507
FileHandler.gdf_to_file(sem_gdf, save_path, silence_warnings=True)

0 commit comments

Comments
 (0)