Skip to content

Commit 0f8d4fe

Browse files
committed
initial prototype
1 parent 44c4994 commit 0f8d4fe

File tree

4 files changed

+298
-177
lines changed

4 files changed

+298
-177
lines changed

test.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,6 @@
1616
device="cuda",
1717
save_dir=pathlib.Path("/media/u1910100/data/overlays/test"),
1818
overwrite=True,
19+
output_type="annotationstore",
20+
class_dict={0: "nucleus"},
1921
)

tiatoolbox/models/architecture/mapde.py

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,16 @@
1212
import torch
1313
import torch.nn.functional as F # noqa: N812
1414
from skimage.feature import peak_local_max
15+
import dask.array as da
16+
from tiatoolbox.annotation.storage import SQLiteStore
17+
import pandas as pd
1518

1619
from tiatoolbox.models.architecture.micronet import MicroNet
17-
20+
from tiatoolbox.models.engine.nucleus_detector import (
21+
peak_detection_mapoverlap,
22+
centroids_map_to_dask_dataframe,
23+
nucleus_detection_nms,
24+
)
1825

1926
class MapDe(MicroNet):
2027
"""Initialize MapDe [1].
@@ -231,30 +238,57 @@ def forward(self: MapDe, input_tensor: torch.Tensor) -> torch.Tensor:
231238
logits, _, _, _ = super().forward(input_tensor)
232239
out = F.conv2d(logits, self.dist_filter, padding="same")
233240
return F.relu(out)
241+
242+
243+
244+
234245

235246
# skipcq: PYL-W0221 # noqa: ERA001
236-
def postproc(self: MapDe, prediction_map: np.ndarray) -> np.ndarray:
237-
"""Post-processing script for MicroNet.
247+
def postproc(self: MapDe, prediction_map: da.Array, prediction_shape: tuple, dtype: np.dtype) -> pd.DataFrame:
248+
"""Post-processing script for MapDe.
238249
239250
Performs peak detection and extracts coordinates in x, y format.
240251
241252
Args:
242-
prediction_map (ndarray):
243-
Input image of type numpy array.
253+
prediction_map (da.array):
254+
Predicted probability map (HxWx1) of the entire input image.
244255
245256
Returns:
246-
:class:`numpy.ndarray`:
247-
Pixel-wise nuclear instance segmentation
248-
prediction.
257+
detected_nuclei (pandas.DataFrame):
258+
Detected nuclei coordinates stored in a pandas DataFrame.
249259
250260
"""
251-
coordinates = peak_local_max(
252-
np.squeeze(prediction_map[0], axis=2),
261+
# coordinates = peak_local_max(
262+
# np.squeeze(prediction_map[0], axis=2),
263+
# min_distance=self.min_distance,
264+
# threshold_abs=self.threshold_abs,
265+
# exclude_border=False,
266+
# )
267+
# return np.fliplr(coordinates)
268+
269+
depth = {0: self.min_distance, 1: self.min_distance, 2: 0}
270+
scores = da.map_overlap(
271+
prediction_map,
272+
peak_detection_mapoverlap,
273+
depth=depth,
274+
boundary=0,
275+
dtype=dtype,
276+
block_info=True,
253277
min_distance=self.min_distance,
254278
threshold_abs=self.threshold_abs,
255-
exclude_border=False,
279+
depth_h=self.min_distance,
280+
depth_w=self.min_distance,
281+
calculate_probabilities=False,
256282
)
257-
return np.fliplr(coordinates)
283+
ddf = centroids_map_to_dask_dataframe(scores, x_offset=0, y_offset=0)
284+
pandas_df = ddf.compute()
285+
286+
print("Total detections before NMS:", len(pandas_df))
287+
nms_df = nucleus_detection_nms(pandas_df, radius=self.min_distance)
288+
print("Total detections after NMS:", len(nms_df))
289+
290+
return nms_df
291+
258292

259293
@staticmethod
260294
def infer_batch(

0 commit comments

Comments
 (0)