@@ -239,34 +239,44 @@ def forward(self: MapDe, input_tensor: torch.Tensor) -> torch.Tensor:
239239 def postproc (
240240 self : MapDe ,
241241 block : np .ndarray ,
242- block_info : dict ,
243- depth_h : int ,
244- depth_w : int ,
242+ block_info : dict | None = None ,
243+ depth_h : int = 0 ,
244+ depth_w : int = 0 ,
245245 ) -> np .ndarray :
246- """Runs inside Dask.da.map_overlap on a padded NumPy block: (h_pad, w_pad, C) .
246+ """ MapDe post-processing function .
247247
248- Builds a processed mask per channel, runs peak_local_max then
249- writes 1.0 at centroid pixels.
250- Keeps only centroids whose (row,col) lie in the interior window:
248+ Builds a processed mask per input channel, runs peak_local_max then
249+ writes 1.0 at peak pixels.
250+
251+ Can be called inside Dask.da.map_overlap on a padded NumPy block: (h_pad, w_pad, C)
252+ to process large prediction maps in chunks. Keeps only centroids whose (row,col)
253+ lie in the interior window:
251254 rows [depth_h : depth_h + core_h), cols [depth_w : depth_w + core_w)
252- Returns same spatial shape as input block: (h_pad, w_pad, C), float32.
255+
256+ Returns same spatial shape as the input block
253257
254258 Args:
255- block: NumPy array (H, W, C) with padded block data.
256- block_info: Dask block info dict.
257- depth_h: Halo size in pixels for height (rows).
258- depth_w: Halo size in pixels for width (cols).
259+ block: NumPy array (H, W, C).
260+ block_info: Dask block info dict. Only used when called inside dask.array.map_overlap.
261+ depth_h: Halo size in pixels for height (rows).
262+ Only used when it's called inside dask.array.map_overlap.
263+ depth_w: Halo size in pixels for width (cols).
264+ Only used when it's called inside dask.array.map_overlap.
259265
260266 Returns:
261- out: NumPy array (H, W, C) with 1 at centroids , 0 elsewhere.
267+ out: NumPy array (H, W, C) with 1.0 at peaks , 0 elsewhere.
262268 """
263269 block_height , block_width , block_channels = block .shape
264270
265271 # --- derive core (pre-overlap) size for THIS block ---
266- info = block_info [0 ]
267- locs = info ["array-location" ] # a list of (start, stop) coordinates per axis
268- core_h = int (locs [0 ][1 ] - locs [0 ][0 ]) # r1 - r0
269- core_w = int (locs [1 ][1 ] - locs [1 ][0 ])
272+ if block_info is None :
273+ core_h = block_height - 2 * depth_h
274+ core_w = block_width - 2 * depth_w
275+ else :
276+ info = block_info [0 ]
277+ locs = info ["array-location" ] # a list of (start, stop) coordinates per axis
278+ core_h = int (locs [0 ][1 ] - locs [0 ][0 ]) # r1 - r0
279+ core_w = int (locs [1 ][1 ] - locs [1 ][0 ])
270280
271281 rmin , rmax = depth_h , depth_h + core_h
272282 cmin , cmax = depth_w , depth_w + core_w
0 commit comments