|
12 | 12 | import torch |
13 | 13 | import torch.nn.functional as F # noqa: N812 |
14 | 14 | from skimage.feature import peak_local_max |
| 15 | +import dask.array as da |
| 16 | +from tiatoolbox.annotation.storage import SQLiteStore |
| 17 | +import pandas as pd |
15 | 18 |
|
16 | 19 | from tiatoolbox.models.architecture.micronet import MicroNet |
17 | | - |
| 20 | +from tiatoolbox.models.engine.nucleus_detector import ( |
| 21 | + peak_detection_mapoverlap, |
| 22 | + centroids_map_to_dask_dataframe, |
| 23 | + nucleus_detection_nms, |
| 24 | +) |
18 | 25 |
|
19 | 26 | class MapDe(MicroNet): |
20 | 27 | """Initialize MapDe [1]. |
@@ -231,30 +238,57 @@ def forward(self: MapDe, input_tensor: torch.Tensor) -> torch.Tensor: |
231 | 238 | logits, _, _, _ = super().forward(input_tensor) |
232 | 239 | out = F.conv2d(logits, self.dist_filter, padding="same") |
233 | 240 | return F.relu(out) |
| 241 | + |
| 242 | + |
| 243 | + |
| 244 | + |
234 | 245 |
|
235 | 246 | # skipcq: PYL-W0221 # noqa: ERA001 |
236 | | - def postproc(self: MapDe, prediction_map: np.ndarray) -> np.ndarray: |
237 | | - """Post-processing script for MicroNet. |
| 247 | + def postproc(self: MapDe, prediction_map: da.Array, prediction_shape: tuple, dtype: np.dtype) -> pd.DataFrame: |
| 248 | + """Post-processing script for MapDe. |
238 | 249 |
|
239 | 250 | Performs peak detection and extracts coordinates in x, y format. |
240 | 251 |
|
241 | 252 | Args: |
242 | | - prediction_map (ndarray): |
243 | | - Input image of type numpy array. |
| 253 | + prediction_map (da.array): |
| 254 | + Predicted probability map (HxWx1) of the entire input image. |
244 | 255 |
|
245 | 256 | Returns: |
246 | | - :class:`numpy.ndarray`: |
247 | | - Pixel-wise nuclear instance segmentation |
248 | | - prediction. |
| 257 | + detected_nuclei (pandas.DataFrame): |
| 258 | + Detected nuclei coordinates stored in a pandas DataFrame. |
249 | 259 |
|
250 | 260 | """ |
251 | | - coordinates = peak_local_max( |
252 | | - np.squeeze(prediction_map[0], axis=2), |
| 261 | + # coordinates = peak_local_max( |
| 262 | + # np.squeeze(prediction_map[0], axis=2), |
| 263 | + # min_distance=self.min_distance, |
| 264 | + # threshold_abs=self.threshold_abs, |
| 265 | + # exclude_border=False, |
| 266 | + # ) |
| 267 | + # return np.fliplr(coordinates) |
| 268 | + |
| 269 | + depth = {0: self.min_distance, 1: self.min_distance, 2: 0} |
| 270 | + scores = da.map_overlap( |
| 271 | + prediction_map, |
| 272 | + peak_detection_mapoverlap, |
| 273 | + depth=depth, |
| 274 | + boundary=0, |
| 275 | + dtype=dtype, |
| 276 | + block_info=True, |
253 | 277 | min_distance=self.min_distance, |
254 | 278 | threshold_abs=self.threshold_abs, |
255 | | - exclude_border=False, |
| 279 | + depth_h=self.min_distance, |
| 280 | + depth_w=self.min_distance, |
| 281 | + calculate_probabilities=False, |
256 | 282 | ) |
257 | | - return np.fliplr(coordinates) |
| 283 | + ddf = centroids_map_to_dask_dataframe(scores, x_offset=0, y_offset=0) |
| 284 | + pandas_df = ddf.compute() |
| 285 | + |
| 286 | + print("Total detections before NMS:", len(pandas_df)) |
| 287 | + nms_df = nucleus_detection_nms(pandas_df, radius=self.min_distance) |
| 288 | + print("Total detections after NMS:", len(nms_df)) |
| 289 | + |
| 290 | + return nms_df |
| 291 | + |
258 | 292 |
|
259 | 293 | @staticmethod |
260 | 294 | def infer_batch( |
|
0 commit comments