11import gc
22from functools import partial
33from pathlib import Path
4- from typing import Dict , List , Tuple , Union
4+ from typing import Callable , Dict , List , Tuple , Union
55
66import numpy as np
77import torch
@@ -52,13 +52,16 @@ def postproc_tissuemap(
5252 save_path : Union [Path , str ] = None ,
5353 coords : Tuple [int , int , int , int ] = None ,
5454 class_dict : Dict [str , int ] = None ,
55+ smooth_func : Callable = None ,
5556 ) -> np .ndarray :
5657 """Run tissue map post-processing."""
5758 tissue_map = remove_debris_semantic (tissue_map , min_size = 5000 )
5859 tissue_map = fill_holes_semantic (tissue_map , min_size = 5000 ).astype ("i4" )
5960
6061 if save_path is not None :
61- self ._save_sem2vector (save_path , tissue_map , coords , class_dict )
62+ self ._save_sem2vector (
63+ save_path , tissue_map , coords , class_dict , smooth_func = smooth_func
64+ )
6265 gc .collect ()
6366 else :
6467 gc .collect ()
@@ -72,13 +75,21 @@ def postproc_inst(
7275 save_path : Union [Path , str ] = None ,
7376 coords : Tuple [int , int , int , int ] = None ,
7477 class_dict : Dict [str , int ] = None ,
78+ smooth_func : Callable = gaussian_smooth ,
7579 ) -> Tuple [np .ndarray , np .ndarray ]:
7680 """Run instace map post-processing."""
7781 inst_map = self .postproc_func (inst_map , aux_map ).astype ("i4" )
7882 type_map = majority_vote_sequential (type_map , inst_map ).astype ("i4" )
7983
8084 if save_path is not None :
81- self ._save_inst2vector (save_path , inst_map , type_map , coords , class_dict )
85+ self ._save_inst2vector (
86+ save_path ,
87+ inst_map ,
88+ type_map ,
89+ coords ,
90+ class_dict ,
91+ smooth_func = smooth_func ,
92+ )
8293 gc .collect ()
8394 else :
8495 gc .collect ()
@@ -97,6 +108,9 @@ def postproc_parallel(
97108 class_dict_nuc : Dict [int , str ] = None ,
98109 class_dict_cyto : Dict [int , str ] = None ,
99110 class_dict_tissue : Dict [int , str ] = None ,
111+ nuc_smooth_func : Callable = gaussian_smooth ,
112+ cyto_smooth_func : Callable = gaussian_smooth ,
113+ tissue_smooth_func : Callable = None ,
100114 ) -> Dict [str , List [np .ndarray ]]:
101115 """Post-process the masks in parallel using multiprocessing."""
102116 # set up input args for
@@ -135,7 +149,11 @@ def postproc_parallel(
135149 if soft_masks ["nuc" ] is not None :
136150 nuc_results = self ._pool_map (
137151 pool ,
138- partial (self .postproc_inst , class_dict = class_dict_nuc ),
152+ partial (
153+ self .postproc_inst ,
154+ class_dict = class_dict_nuc ,
155+ smooth_func = nuc_smooth_func ,
156+ ),
139157 list (
140158 zip (
141159 nuc_inst_maps ,
@@ -151,7 +169,11 @@ def postproc_parallel(
151169 if soft_masks ["cyto" ] is not None :
152170 cyto_results = self ._pool_map (
153171 pool ,
154- partial (self .postproc_inst , class_dict = class_dict_cyto ),
172+ partial (
173+ self .postproc_inst ,
174+ class_dict = class_dict_cyto ,
175+ smooth_func = cyto_smooth_func ,
176+ ),
155177 list (
156178 zip (
157179 cyto_inst_maps ,
@@ -167,7 +189,11 @@ def postproc_parallel(
167189 if soft_masks ["tissue" ] is not None :
168190 tissue_results = self ._pool_map (
169191 pool ,
170- partial (self .postproc_tissuemap , class_dict = class_dict_tissue ),
192+ partial (
193+ self .postproc_tissuemap ,
194+ class_dict = class_dict_tissue ,
195+ smooth_func = tissue_smooth_func ,
196+ ),
171197 list (zip (tissue_maps , save_paths_tissue , coords )),
172198 progress_bar = progress_bar ,
173199 )
@@ -186,6 +212,9 @@ def postproc_parallel_async(
186212 class_dict_nuc : Dict [int , str ] = None ,
187213 class_dict_cyto : Dict [int , str ] = None ,
188214 class_dict_tissue : Dict [int , str ] = None ,
215+ nuc_smooth_func : Callable = gaussian_smooth ,
216+ cyto_smooth_func : Callable = gaussian_smooth ,
217+ tissue_smooth_func : Callable = None ,
189218 ) -> Dict [str , List [np .ndarray ]]:
190219 """Post-process the masks in parallel using async."""
191220 # set up input args for
@@ -225,7 +254,11 @@ def postproc_parallel_async(
225254 if soft_masks ["nuc" ] is not None :
226255 nuc_results = self ._pool_apply_async (
227256 pool ,
228- partial (self .postproc_inst , class_dict = class_dict_nuc ),
257+ partial (
258+ self .postproc_inst ,
259+ class_dict = class_dict_nuc ,
260+ smooth_func = nuc_smooth_func ,
261+ ),
229262 list (
230263 zip (
231264 nuc_inst_maps ,
@@ -240,7 +273,11 @@ def postproc_parallel_async(
240273 if soft_masks ["cyto" ] is not None :
241274 cyto_results = self ._pool_apply_async (
242275 pool ,
243- partial (self .postproc_inst , class_dict = class_dict_cyto ),
276+ partial (
277+ self .postproc_inst ,
278+ class_dict = class_dict_cyto ,
279+ smooth_func = cyto_smooth_func ,
280+ ),
244281 list (
245282 zip (
246283 cyto_inst_maps ,
@@ -255,7 +292,11 @@ def postproc_parallel_async(
255292 if soft_masks ["tissue" ] is not None :
256293 tissue_results = self ._pool_apply_async (
257294 pool ,
258- partial (self .postproc_tissuemap , class_dict = class_dict_tissue ),
295+ partial (
296+ self .postproc_tissuemap ,
297+ class_dict = class_dict_tissue ,
298+ smooth_func = tissue_smooth_func ,
299+ ),
259300 list (zip (tissue_maps , save_paths_tissue , coords )),
260301 )
261302
@@ -276,6 +317,9 @@ def postproc_serial(
276317 class_dict_nuc : Dict [int , str ] = None ,
277318 class_dict_cyto : Dict [int , str ] = None ,
278319 class_dict_tissue : Dict [int , str ] = None ,
320+ nuc_smooth_func : Callable = gaussian_smooth ,
321+ cyto_smooth_func : Callable = gaussian_smooth ,
322+ tissue_smooth_func : Callable = None ,
279323 ) -> Dict [str , List [np .ndarray ]]:
280324 """Run post-processing sequentially."""
281325 nuc_inst_maps , nuc_aux_maps , nuc_type_maps = self ._prepare_inst_maps (
@@ -306,7 +350,13 @@ def postproc_serial(
306350 if soft_masks ["nuc" ] is not None :
307351 nuc_results = [
308352 self .postproc_inst (
309- inst_map , aux_map , type_map , save_path , coord , class_dict_nuc
353+ inst_map ,
354+ aux_map ,
355+ type_map ,
356+ save_path ,
357+ coord ,
358+ class_dict_nuc ,
359+ smooth_func = nuc_smooth_func ,
310360 )
311361 for inst_map , aux_map , type_map , save_path , coord in zip (
312362 nuc_inst_maps ,
@@ -320,7 +370,13 @@ def postproc_serial(
320370 if soft_masks ["cyto" ] is not None :
321371 cyto_results = [
322372 self .postproc_inst (
323- inst_map , aux_map , type_map , save_path , coord , class_dict_cyto
373+ inst_map ,
374+ aux_map ,
375+ type_map ,
376+ save_path ,
377+ coord ,
378+ class_dict_cyto ,
379+ smooth_func = cyto_smooth_func ,
324380 )
325381 for inst_map , aux_map , type_map , save_path , coord in zip (
326382 cyto_inst_maps ,
@@ -333,7 +389,13 @@ def postproc_serial(
333389
334390 if soft_masks ["tissue" ] is not None :
335391 tissue_results = [
336- self .postproc_tissuemap (tissue_map , save_path , coord , class_dict_tissue )
392+ self .postproc_tissuemap (
393+ tissue_map ,
394+ save_path ,
395+ coord ,
396+ class_dict_tissue ,
397+ smooth_func = tissue_smooth_func ,
398+ )
337399 for tissue_map , save_path , coord in zip (
338400 tissue_maps , save_paths_tissue , coords
339401 )
@@ -398,6 +460,7 @@ def _save_inst2vector(
398460 class_dict : dict = None ,
399461 compute_centroids : bool = False ,
400462 compute_bboxes : bool = False ,
463+ smooth_func : Callable = gaussian_smooth ,
401464 ) -> None :
402465 save_path = Path (save_path )
403466
@@ -410,7 +473,7 @@ def _save_inst2vector(
410473 xoff = xoff ,
411474 yoff = yoff ,
412475 class_dict = class_dict ,
413- smooth_func = gaussian_smooth ,
476+ smooth_func = smooth_func ,
414477 )
415478
416479 if compute_centroids :
@@ -426,6 +489,7 @@ def _save_sem2vector(
426489 sem_map : np .ndarray ,
427490 coords : List [Tuple [int , int , int , int ]] = None ,
428491 class_dict : dict = None ,
492+ smooth_func : Callable = None ,
429493 ) -> None :
430494 save_path = Path (save_path )
431495
@@ -437,6 +501,7 @@ def _save_sem2vector(
437501 xoff = xoff ,
438502 yoff = yoff ,
439503 class_dict = class_dict ,
504+ smooth_func = smooth_func ,
440505 )
441506
442507 FileHandler .gdf_to_file (sem_gdf , save_path , silence_warnings = True )
0 commit comments