
Commit d42b78a

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 0f8d4fe commit d42b78a

File tree

3 files changed: +64, -55 lines changed

tiatoolbox/models/architecture/mapde.py

Lines changed: 7 additions & 11 deletions
@@ -8,21 +8,20 @@
 
 from __future__ import annotations
 
+import dask.array as da
 import numpy as np
+import pandas as pd
 import torch
 import torch.nn.functional as F  # noqa: N812
-from skimage.feature import peak_local_max
-import dask.array as da
-from tiatoolbox.annotation.storage import SQLiteStore
-import pandas as pd
 
 from tiatoolbox.models.architecture.micronet import MicroNet
 from tiatoolbox.models.engine.nucleus_detector import (
-    peak_detection_mapoverlap,
     centroids_map_to_dask_dataframe,
     nucleus_detection_nms,
+    peak_detection_mapoverlap,
 )
 
+
 class MapDe(MicroNet):
     """Initialize MapDe [1].
 
@@ -238,13 +237,11 @@ def forward(self: MapDe, input_tensor: torch.Tensor) -> torch.Tensor:
         logits, _, _, _ = super().forward(input_tensor)
         out = F.conv2d(logits, self.dist_filter, padding="same")
         return F.relu(out)
-
-
-
-
 
     # skipcq: PYL-W0221  # noqa: ERA001
-    def postproc(self: MapDe, prediction_map: da.Array, prediction_shape: tuple, dtype: np.dtype) -> pd.DataFrame:
+    def postproc(
+        self: MapDe, prediction_map: da.Array, prediction_shape: tuple, dtype: np.dtype
+    ) -> pd.DataFrame:
         """Post-processing script for MapDe.
 
         Performs peak detection and extracts coordinates in x, y format.
@@ -289,7 +286,6 @@ def postproc(self: MapDe, prediction_map: da.Array, prediction_shape: tuple, dty
 
         return nms_df
 
-
    @staticmethod
    def infer_batch(
        model: torch.nn.Module,

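The forward() hunk above applies a fixed distance filter to the logits with "same" padding and then a ReLU. A minimal sketch of that pattern, with a made-up 3x3 box kernel standing in for the model's dist_filter:

import torch
import torch.nn.functional as F

# Hypothetical stand-in for MapDe's dist_filter: a 3x3 box kernel
# applied to a single-channel dummy probability map.
logits = torch.rand(1, 1, 16, 16)           # (N, C, H, W) dummy logits
dist_filter = torch.ones(1, 1, 3, 3) / 9.0  # (out_ch, in_ch, kH, kW)

# padding="same" preserves the spatial size; ReLU zeroes negative responses.
out = F.relu(F.conv2d(logits, dist_filter, padding="same"))
assert out.shape == logits.shape

Note that padding="same" requires the default stride of 1, which is what the call in the diff uses.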
tiatoolbox/models/engine/nucleus_detector.py

Lines changed: 42 additions & 34 deletions
@@ -2,53 +2,53 @@
 
 from __future__ import annotations
 
-import os
-import sys
 from pathlib import Path
 from typing import TYPE_CHECKING, Unpack
 
 import dask.array as da
 import dask.dataframe as dd
 import numpy as np
 import pandas as pd
-from shapely.geometry import Point
 from skimage.feature import peak_local_max
 from skimage.measure import label, regionprops
 
-from tiatoolbox.models.engine.io_config import IOSegmentorConfig
+from tiatoolbox import logger
+from tiatoolbox.annotation import AnnotationStore
 from tiatoolbox.models.engine.semantic_segmentor import (
     SemanticSegmentor,
     SemanticSegmentorRunParams,
 )
 from tiatoolbox.models.models_abc import ModelABC
-from tiatoolbox.annotation import Annotation, SQLiteStore, AnnotationStore
 from tiatoolbox.utils.misc import df_to_store_nucleus_detector
-from tiatoolbox import logger
 
 if TYPE_CHECKING:  # pragma: no cover
-    import os
-    from tiatoolbox.models.engine.io_config import IOSegmentorConfig
     from tiatoolbox.models.models_abc import ModelABC
-    from tiatoolbox.wsicore import WSIReader
 
 
 def probability_to_peak_map(
-    img2d: np.ndarray, min_distance: int, threshold_abs: float, threshold_rel: float = 0.0
+    img2d: np.ndarray,
+    min_distance: int,
+    threshold_abs: float,
+    threshold_rel: float = 0.0,
 ) -> np.ndarray:
     """Build a boolean mask (H, W) of objects from a 2D probability map using peak_local_max.
-
+
     Args:
         img2d (np.ndarray): 2D probability map.
         min_distance (int): Minimum distance between peaks.
         threshold_abs (float): Absolute threshold for peak detection.
         threshold_rel (float, optional): Relative threshold for peak detection. Defaults to 0.0.
+
     Returns:
         mask (np.ndarray): Boolean mask (H, W) with True at peak locations.
     """
     H, W = img2d.shape
     mask = np.zeros((H, W), dtype=bool)
     coords = peak_local_max(
-        img2d, min_distance=min_distance, threshold_abs=threshold_abs, threshold_rel=threshold_rel
+        img2d,
+        min_distance=min_distance,
+        threshold_abs=threshold_abs,
+        threshold_rel=threshold_rel,
    )
    if coords.size:
        r, c = coords[:, 0], coords[:, 1]
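probability_to_peak_map wraps skimage.feature.peak_local_max, which returns peak coordinates rather than a mask, so the function scatters those coordinates back into a boolean image. A minimal sketch of that conversion on a toy map (the values are made up):

import numpy as np
from skimage.feature import peak_local_max

img2d = np.zeros((8, 8))
img2d[2, 2] = 0.9  # two synthetic peaks standing in for nuclei
img2d[6, 5] = 0.7

coords = peak_local_max(img2d, min_distance=2, threshold_abs=0.5)  # (N, 2) rows/cols
mask = np.zeros(img2d.shape, dtype=bool)
if coords.size:
    mask[coords[:, 0], coords[:, 1]] = True  # True only at peak pixels

print(np.argwhere(mask))  # [[2 2], [6 5]]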
@@ -67,7 +67,7 @@ def peak_detection_mapoverlap(
 ) -> np.ndarray:
     """Runs inside Dask.da.map_overlap on a padded NumPy block: (h_pad, w_pad, C).
     Builds a processed mask per channel, runs peak_local_max then
-    label+regionprops, and writes probability (mean_intensity) at centroid pixels.
+        label+regionprops, and writes probability (mean_intensity) at centroid pixels.
     Keeps only centroids whose (row,col) lie in the interior window:
     rows [depth_h : depth_h + core_h), cols [depth_w : depth_w + core_w)
     Returns same spatial shape as input block: (h_pad, w_pad, C), float32.
@@ -81,6 +81,7 @@ def peak_detection_mapoverlap(
         depth_w: Halo size in pixels for width (cols).
         calculate_probabilities: If True, write mean_intensity at centroids;
             else write 1.0 at centroids.
+
     Returns:
         out: NumPy array (H, W, C) with probabilities at centroids, 0 elsewhere.
     """
@@ -120,7 +121,9 @@
     return out
 
 
-def detection_with_map_overlap(probs: da.Array, min_distance: int, threshold_abs: float, depth_pixels: int) -> da.Array:
+def detection_with_map_overlap(
+    probs: da.Array, min_distance: int, threshold_abs: float, depth_pixels: int
+) -> da.Array:
    """probs: Dask array (H, W, C), float.
    depth_pixels: halo in pixels for H/W (use >= min_distance and >= any morphology radius).
 
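detection_with_map_overlap builds on dask.array's map_overlap, which hands each chunk to the per-block function padded with a halo of depth pixels taken from its neighbours, so peaks near chunk borders see the same neighbourhood as interior peaks; that is why the docstring asks for depth_pixels >= min_distance. A minimal sketch of the halo mechanics on a dummy array (the per-block function here is a placeholder, not the toolbox's):

import dask.array as da
import numpy as np

x = da.random.random((100, 100, 2), chunks=(50, 50, 1))

def per_block(block):
    # Each block arrives padded by `depth` pixels on H and W; with the
    # default trim=True the halo is cut away again after the function runs.
    return block.astype(np.float32)

depth = {0: 8, 1: 8, 2: 0}  # halo on rows/cols only, none across channels
y = x.map_overlap(per_block, depth=depth, boundary="reflect", dtype=np.float32)
assert y.shape == x.shape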
@@ -143,18 +146,21 @@ def detection_with_map_overlap(probs: da.Array, min_distance: int, threshold_abs
     return scores
 
 
-def centroids_map_to_dask_dataframe(scores: da.Array, x_offset: int = 0, y_offset: int = 0) -> dd.DataFrame:
+def centroids_map_to_dask_dataframe(
+    scores: da.Array, x_offset: int = 0, y_offset: int = 0
+) -> dd.DataFrame:
     """Convert centroid map (H, W, C) into a Dask DataFrame with columns: x, y, type, prob.
 
     Args:
         scores: Dask array (H, W, C) with probabilities at centroids, 0 elsewhere.
         x_offset: global x offset to add to all x coordinates.
         y_offset: global y offset to add to all y coordinates.
+
     Returns:
         ddf: Dask DataFrame with columns: x, y, type, prob.
     """
     # 1) Build a boolean mask of detections
-
+
     mask = scores > 0
     # 2) Get coordinates and class of detections (lazy 1D Dask arrays)
 
@@ -172,7 +178,7 @@ def centroids_map_to_dask_dataframe(scores: da.Array, x_offset: int = 0, y_offse
             dd.from_dask_array(ss.astype("float32"), columns="prob"),
         ],
         axis=1,
-        ignore_unknown_divisions=True
+        ignore_unknown_divisions=True,
    )
 
    # 5) Apply global offsets (if needed)
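The dd.concat(..., axis=1, ignore_unknown_divisions=True) call above is the final step of flattening the sparse centroid map into a table. A minimal sketch of the whole pattern on a tiny array; the coordinate-recovery steps are one plausible way to produce the lazy 1D arrays (the hunk only shows the concat), and all names are illustrative:

import dask.array as da
import dask.dataframe as dd
import numpy as np

# Toy centroid map: one detection at (row=1, col=2) in channel 0.
scores_np = np.zeros((4, 4, 1), dtype=np.float32)
scores_np[1, 2, 0] = 0.9
scores = da.from_array(scores_np, chunks=(2, 4, 1))  # whole rows per chunk

H, W, C = scores.shape
flat = scores.reshape(-1)
keep = flat > 0                                        # boolean detection mask
idx = da.arange(flat.shape[0], chunks=flat.chunks[0])[keep]
ss = flat[keep]                                        # probabilities at detections

# Recover (row, col, channel) from the flat index, still lazily.
rr, cc, ch = idx // (W * C), (idx // C) % W, idx % C

ddf = dd.concat(
    [
        dd.from_dask_array(cc.astype("int32"), columns="x"),  # col -> x
        dd.from_dask_array(rr.astype("int32"), columns="y"),  # row -> y
        dd.from_dask_array(ch.astype("int32"), columns="type"),
        dd.from_dask_array(ss.astype("float32"), columns="prob"),
    ],
    axis=1,
    ignore_unknown_divisions=True,  # row count is unknown until compute
)
print(ddf.compute())  # one row: x=2, y=1, type=0, prob=0.9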
@@ -184,7 +190,9 @@ def centroids_map_to_dask_dataframe(scores: da.Array, x_offset: int = 0, y_offse
     return ddf
 
 
-def nucleus_detection_nms(df: pd.DataFrame, radius: int, overlap_threshold:float = 0.5) -> pd.DataFrame:
+def nucleus_detection_nms(
+    df: pd.DataFrame, radius: int, overlap_threshold: float = 0.5
+) -> pd.DataFrame:
     """Greedy NMS across ALL detections.
 
     Keeps the highest-prob detection, removes any other point within 'radius' pixels > overlap_threshold.
@@ -215,7 +223,7 @@ def nucleus_detection_nms(df: pd.DataFrame, radius: int, overlap_threshold:float
     coords = sub[["x", "y"]].to_numpy(dtype=np.float64)
     r = float(radius)
     two_r = 2.0 * r
-    two_r2 = (two_r * two_r)  # distance^2 cutoff for any overlap
+    two_r2 = two_r * two_r  # distance^2 cutoff for any overlap
 
     suppressed = np.zeros(len(sub), dtype=bool)
     keep_idx = []
@@ -232,18 +240,19 @@ def nucleus_detection_nms(df: pd.DataFrame, radius: int, overlap_threshold:float
         d2 = dx * dx + dy * dy
 
         # Only points with d < 2r can have nonzero overlap
-        cand = (d2 <= two_r2)
+        cand = d2 <= two_r2
         cand[i] = False  # don't suppress the kept point itself
         if not np.any(cand):
             continue
 
         d = np.sqrt(d2[cand])
 
-
         # Safe cosine argument = (distance ÷ diameter), Clamp for numerical stability
         u = np.clip(d / (2.0 * r), -1.0, 1.0)
         # Exact intersection area of two equal-radius circles.
-        inter = 2.0 * (r * r) * np.arccos(u) - 0.5 * d * np.sqrt(np.clip(4.0 * r * r - d * d, 0.0, None))
+        inter = 2.0 * (r * r) * np.arccos(u) - 0.5 * d * np.sqrt(
+            np.clip(4.0 * r * r - d * d, 0.0, None)
+        )
 
         union = 2.0 * np.pi * (r * r) - inter
         iou = inter / union
@@ -252,7 +261,7 @@ def nucleus_detection_nms(df: pd.DataFrame, radius: int, overlap_threshold:float
         idx_cand = np.where(cand)[0]
         to_suppress = idx_cand[iou >= overlap_threshold]
         suppressed[to_suppress] = True
-
+
    kept = sub.iloc[keep_idx].copy()
    return kept
 
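The reformatted inter expression is the closed-form intersection area of two equal circles of radius r whose centres are d apart: inter = 2*r^2*arccos(d/(2r)) - (d/2)*sqrt(4*r^2 - d^2), with union = 2*pi*r^2 - inter. A standalone numeric check of the IoU this produces (not toolbox code):

import numpy as np

def circle_iou(d, r):
    # Safe arccos argument, clamped exactly as in the diff above.
    u = np.clip(d / (2.0 * r), -1.0, 1.0)
    inter = 2.0 * (r * r) * np.arccos(u) - 0.5 * d * np.sqrt(
        np.clip(4.0 * r * r - d * d, 0.0, None)
    )
    union = 2.0 * np.pi * (r * r) - inter
    return inter / union

d = np.array([0.0, 6.0, 12.0])
print(circle_iou(d, r=6.0))  # [1.0, ~0.24, 0.0]

d = 0 gives IoU 1 (coincident circles) and d = 2r gives 0 (tangent circles), so the greedy loop above, with the default overlap_threshold = 0.5, suppresses only strongly overlapping detections.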
@@ -363,6 +372,7 @@ def post_process_patches(
             raw_predictions (da.Array): The raw predictions from the model.
             prediction_shape (tuple[int, ...]): The shape of the predictions.
             prediction_dtype (type): The data type of the predictions.
+
         Returns:
             A list of DataFrames containing the post-processed predictions for each patch.
 
@@ -376,7 +386,6 @@ def post_process_patches(
             batch_predictions.append(self.model.postproc_func(raw_predictions[i]))
         return batch_predictions
 
-
     def post_process_wsi(
         self: NucleusDetector,
         raw_predictions: da.Array,
@@ -396,8 +405,9 @@ def post_process_wsi(
         logger.info(f"Raw probabilities dtype: {prediction_dtype}")
         logger.info(f"Chunk size: {raw_predictions.chunks}")
 
-        detection_df = self.model.postproc(raw_predictions, prediction_shape, prediction_dtype)
-
+        detection_df = self.model.postproc(
+            raw_predictions, prediction_shape, prediction_dtype
+        )
 
         return detection_df
 
@@ -441,11 +451,9 @@ def save_predictions(
 
             save_paths.append(out_file)
             return save_paths
-        else:
-            return df_to_store_nucleus_detector(
-                processed_predictions['predictions'],
-                scale_factor=scale_factor,
-                save_path=save_path,
-                class_dict=class_dict,
-            )
-
+        return df_to_store_nucleus_detector(
+            processed_predictions["predictions"],
+            scale_factor=scale_factor,
+            save_path=save_path,
+            class_dict=class_dict,
+        )

tiatoolbox/utils/misc.py

Lines changed: 15 additions & 10 deletions
@@ -21,7 +21,7 @@
 import zarr
 from filelock import FileLock
 from shapely.affinity import translate
-from shapely.geometry import Polygon, Point
+from shapely.geometry import Point, Polygon
 from shapely.geometry import shape as feature2geometry
 from skimage import exposure
 from tqdm import notebook as tqdm_notebook
@@ -1345,18 +1345,17 @@ def df_to_store_nucleus_detector(
     scale_factor: tuple[float, float],
     save_path: Path | None = None,
     class_dict: dict | None = None,
-    batch_size: int = 50_000
+    batch_size: int = 50_000,
 ) -> SQLiteStore | Path:
-    """
-    Convert a pandas DataFrame with columns ['x','y','type','prob']
+    """Convert a pandas DataFrame with columns ['x','y','type','prob']
     into an Annotation SQLiteStore efficiently using append_many().
 
     Args:
         df (pd.DataFrame):
             A pandas DataFrame with columns ['x','y','type','prob'].
         save_path (Path, optional):
             Optional Output directory to save the Annotation
-        Store results.
+            Store results.
     scale_factor (tuple[float, float]):
         The scale factor to use when saving the
         annotations. All coordinates will be multiplied by this factor to allow
@@ -1373,7 +1372,6 @@ def df_to_store_nucleus_detector(
             or Path to file storing SQLiteStore containing Annotations
             for each nucleus.
     """
-
     # 1) Select & coerce dtypes once (compact + avoids per-row casts)
     x = df["x"].to_numpy(dtype=np.int64, copy=False)
     y = df["y"].to_numpy(dtype=np.int64, copy=False)
@@ -1387,7 +1385,7 @@
 
     def make_points(xb, yb):
         return [Point(int(xx), int(yy)) for xx, yy in zip(xb, yb)]
-
+
     if class_dict is None:
         # identity over the actually present types (robust if types aren't 0..K)
         unique_types = np.unique(t)
@@ -1400,9 +1398,16 @@ def make_points(xb, yb):
 
         pts = make_points(xb, yb)  # array/list of Points
 
-        anns = [Annotation(geometry=pt,
-                           properties={"type": class_dict.get(int(tt), int(tt)), "probability": float(pp)})
-                for pt, tt, pp in zip(pts, tb, pb)]
+        anns = [
+            Annotation(
+                geometry=pt,
+                properties={
+                    "type": class_dict.get(int(tt), int(tt)),
+                    "probability": float(pp),
+                },
+            )
+            for pt, tt, pp in zip(pts, tb, pb)
+        ]
 
        store.append_many(anns)
 
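df_to_store_nucleus_detector converts detections in batch_size slices so the Point geometries for millions of nuclei are never all in memory at once, relying on the store's append_many for bulk insertion. A minimal sketch of that batched pattern on a toy frame (column names follow the function's contract; assumes a current tiatoolbox install):

import pandas as pd
from shapely.geometry import Point

from tiatoolbox.annotation.storage import Annotation, SQLiteStore

df = pd.DataFrame({"x": [10, 42], "y": [7, 99], "type": [0, 1], "prob": [0.91, 0.62]})

store = SQLiteStore()  # in-memory by default; pass a path to persist
batch_size = 1         # tiny on purpose; the function defaults to 50_000

for start in range(0, len(df), batch_size):
    chunk = df.iloc[start : start + batch_size]
    anns = [
        Annotation(
            geometry=Point(int(row.x), int(row.y)),
            properties={"type": int(row.type), "probability": float(row.prob)},
        )
        for row in chunk.itertuples()
    ]
    store.append_many(anns)  # bulk insert of the whole batch

print(len(store))  # 2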
0 commit comments