1010from __future__ import annotations
1111
1212from collections import OrderedDict
13+ from typing import TYPE_CHECKING
14+
15+ if TYPE_CHECKING :
16+ import numpy as np
1317
14- import numpy as np
1518import torch
16- from skimage .feature import peak_local_max
1719from torch import nn
1820
21+ from tiatoolbox .models .architecture .utils import peak_detection_da_map_overlap
1922from tiatoolbox .models .models_abc import ModelABC
2023
2124
@@ -92,7 +95,7 @@ def __init__(
9295 min_distance : int = 6 ,
9396 threshold_abs : float = 0.20 ,
9497 postproc_tile_shape : tuple [int , int ] = (2048 , 2048 ),
95- output_class_dict : dict [int , str ] | None = None ,
98+ class_dict : dict [int , str ] | None = None ,
9699 ) -> None :
97100 """Initialize :class:`SCCNN`."""
98101 super ().__init__ ()
@@ -102,7 +105,7 @@ def __init__(
102105 self .out_height = out_height
103106 self .out_width = out_width
104107 self .postproc_tile_shape = postproc_tile_shape
105- self .output_class_dict = output_class_dict
108+ self .output_class_dict = class_dict
106109
107110 # Create mesh grid and convert to 3D vector
108111 x , y = torch .meshgrid (
@@ -341,11 +344,6 @@ def postproc(
341344 Builds a processed mask per input channel, runs peak_local_max then
342345 writes 1.0 at peak pixels.
343346
344- Can be called inside Dask.da.map_overlap on a padded NumPy block:
345- (h_pad, w_pad, C) to process large prediction maps in chunks.
346- Keeps only centroids whose (row,col) lie in the interior window:
347- rows [depth_h : depth_h + core_h), cols [depth_w : depth_w + core_w)
348-
349347 Returns same spatial shape as the input block
350348
351349 Args:
@@ -360,40 +358,14 @@ def postproc(
360358 Returns:
361359 out: NumPy array (H, W, C) with 1.0 at peaks, 0 elsewhere.
362360 """
363- block_height , block_width , block_channels = block .shape
364-
365- # --- derive core (pre-overlap) size for THIS block ---
366- if block_info is None :
367- core_h = block_height - 2 * depth_h
368- core_w = block_width - 2 * depth_w
369- else :
370- info = block_info [0 ]
371- locs = info [
372- "array-location"
373- ] # a list of (start, stop) coordinates per axis
374- core_h = int (locs [0 ][1 ] - locs [0 ][0 ]) # r1 - r0
375- core_w = int (locs [1 ][1 ] - locs [1 ][0 ])
376-
377- rmin , rmax = depth_h , depth_h + core_h
378- cmin , cmax = depth_w , depth_w + core_w
379-
380- out = np .zeros ((block_height , block_width , block_channels ), dtype = np .float32 )
381-
382- for ch in range (block_channels ):
383- img = np .asarray (block [..., ch ]) # NumPy 2D view
384-
385- coords = peak_local_max (
386- img ,
387- min_distance = self .min_distance ,
388- threshold_abs = self .threshold_abs ,
389- exclude_border = False ,
390- )
391-
392- for r , c in coords :
393- if (rmin <= r < rmax ) and (cmin <= c < cmax ):
394- out [r , c , ch ] = 1.0
395-
396- return out
361+ return peak_detection_da_map_overlap (
362+ block ,
363+ min_distance = self .min_distance ,
364+ threshold_abs = self .threshold_abs ,
365+ block_info = block_info ,
366+ depth_h = depth_h ,
367+ depth_w = depth_w ,
368+ )
397369
398370 @staticmethod
399371 def infer_batch (
0 commit comments