Skip to content

Commit b788da7

Browse files
Update prediction check
1 parent df08168 commit b788da7

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

scripts/synapse_marker_detection/check_synapse_prediction.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import h5py
22
import napari
3+
import pandas as pd
34
import zarr
45

6+
# from skimage.feature import blob_dog
7+
from skimage.feature import peak_local_max
58
from torch_em.util import load_model
69
from torch_em.util.prediction import predict_with_halo
710
from train_synapse_detection import get_paths
@@ -16,8 +19,9 @@ def run_prediction(val_image):
1619

1720

1821
def main():
19-
val_paths, _ = get_paths("val")
22+
val_paths, val_labels = get_paths("val")
2023
val_image = zarr.open(val_paths[0])["raw"][:]
24+
val_labels = pd.read_csv(val_labels[0])[["axis-0", "axis-1", "axis-2"]]
2125

2226
# pred = run_prediction(val_image)
2327
# with h5py.File("pred.h5", "a") as f:
@@ -26,9 +30,17 @@ def main():
2630
with h5py.File("pred.h5", "r") as f:
2731
pred = f["pred"][:]
2832

33+
print("Running local max ...")
34+
# coords = blob_dog(pred)
35+
coords = peak_local_max(pred, min_distance=2, threshold_abs=0.2)
36+
# breakpoint()
37+
print("... done")
38+
2939
v = napari.Viewer()
3040
v.add_image(val_image)
3141
v.add_image(pred)
42+
v.add_points(coords)
43+
v.add_points(val_labels, face_color="green")
3244
napari.run()
3345

3446

0 commit comments

Comments
 (0)