Skip to content

Commit 7507d91

Browse files
Implement more debugging for detection model
1 parent 66ed39b commit 7507d91

File tree

4 files changed

+183
-0
lines changed

4 files changed

+183
-0
lines changed
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
import napari
2+
import zarr
3+
4+
5+
resolution = [3.0, 1.887779, 1.887779]
6+
positions = [
7+
[2002.95539395823, 1899.9032205156411, 264.7747008147759]
8+
]
9+
10+
11+
def _load_from_mobie(bb):
12+
path = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet/LaVision-M04/images/ome-zarr/PV.ome.zarr"
13+
f = zarr.open(path, mode="r")
14+
data = f["s0"][bb]
15+
print(bb)
16+
17+
path = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet/LaVision-M04/images/ome-zarr/SGN_detect-v1.ome.zarr"
18+
f = zarr.open(path, mode="r")
19+
seg = f["s0"][bb]
20+
21+
return data, seg
22+
23+
24+
def _load_prediction(bb):
25+
path = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/predictions/LaVision-M04/SGN_detect-v1/predictions.zarr"
26+
f = zarr.open(path, mode="r")
27+
data = f["prediction"][bb]
28+
return data
29+
30+
31+
def _load_prediction_debug():
32+
path = "./debug-pred/pred-v5.h5"
33+
with zarr.open(path, "r") as f:
34+
pred = f["pred"][:]
35+
return pred
36+
37+
38+
def check_detection(position, halo=[32, 384, 384]):
39+
40+
bb = tuple(
41+
slice(int(pos / re) - ha, int(pos / re) + ha) for pos, re, ha in zip(position[::-1], resolution, halo)
42+
)
43+
44+
pv, detections_mobie = _load_from_mobie(bb)
45+
# pred = _load_prediction(bb)
46+
pred = _load_prediction_debug()
47+
48+
v = napari.Viewer()
49+
v.add_image(pv)
50+
v.add_image(pred)
51+
v.add_labels(detections_mobie)
52+
napari.run()
53+
54+
55+
def main():
56+
position = positions[0]
57+
check_detection(position)
58+
59+
60+
if __name__ == "__main__":
61+
main()
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import os
2+
from functools import partial
3+
4+
import numpy as np
5+
import torch
6+
import zarr
7+
from torch_em.transform.raw import standardize
8+
from torch_em.util.prediction import predict_with_halo
9+
10+
11+
resolution = [3.0, 1.887779, 1.887779]
12+
positions = [
13+
[2002.95539395823, 1899.9032205156411, 264.7747008147759]
14+
]
15+
16+
17+
def _load_from_mobie(bb):
18+
path = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/mobie_project/cochlea-lightsheet/LaVision-M04/images/ome-zarr/PV.ome.zarr"
19+
f = zarr.open(path, mode="r")
20+
data = f["s0"][bb]
21+
return data
22+
23+
24+
def run_prediction(position, halo=[32, 384, 384]):
25+
bb = tuple(
26+
slice(int(pos / re) - ha, int(pos / re) + ha) for pos, re, ha in zip(position[::-1], resolution, halo)
27+
)
28+
pv = _load_from_mobie(bb)
29+
mean, std = np.mean(pv), np.std(pv)
30+
print(mean, std)
31+
preproc = partial(standardize, mean=mean, std=std)
32+
33+
block_shape = (24, 256, 256)
34+
halo = (8, 64, 64)
35+
36+
model_path = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/SGN/sgn-detection-v1.pt"
37+
model = torch.load(model_path, weights_only=False)
38+
39+
def postproc(x):
40+
x = np.clip(x, 0, 1)
41+
max_ = np.percentile(x, 99)
42+
x = x / max_
43+
return x
44+
45+
pred = predict_with_halo(pv, model, [0], block_shape, halo, preprocess=preproc, postprocess=postproc).squeeze()
46+
47+
pred_name = "pred-v5"
48+
out_folder = "./debug-pred"
49+
os.makedirs(out_folder, exist_ok=True)
50+
51+
out_path = os.path.join(out_folder, f"{pred_name}.h5")
52+
with zarr.open(out_path, "w") as f:
53+
f.create_dataset("pred", data=pred)
54+
55+
56+
def main():
57+
position = positions[0]
58+
run_prediction(position)
59+
60+
61+
if __name__ == "__main__":
62+
main()

scripts/la-vision/detect_blocks.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import os
2+
import imageio.v3 as imageio
3+
from pathlib import Path
4+
5+
import pandas as pd
6+
import torch
7+
from skimage.feature import peak_local_max
8+
from torch_em.util.prediction import predict_with_halo
9+
10+
ims = [
11+
"/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/SGN/sgn-detection/images/LaVision-M04_crop_2580-2266-0533_PV.tif",
12+
"/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/SGN/sgn-detection/empty_images/LaVision-M04_crop_0400-2500-0840_PV_empty.tif"
13+
]
14+
15+
model_path = "checkpoints/sgn-detection.pt"
16+
model = torch.load(model_path, weights_only=False)
17+
18+
block_shape = [24, 256, 256]
19+
halo = (8, 64, 64)
20+
21+
out = "./detections-v1"
22+
os.makedirs(out, exist_ok=True)
23+
for im in ims:
24+
data = imageio.imread(im)
25+
pred = predict_with_halo(data, model, [0], block_shape, halo).squeeze()
26+
27+
coords = peak_local_max(pred, min_distance=4, threshold_abs=0.5)
28+
29+
# coords = np.concatenate([np.arange(0, len(coords))[:, None], coords], axis=1)
30+
coords = pd.DataFrame(coords, columns=["axis-0", "axis-1", "axis-2"])
31+
32+
name = Path(im).stem
33+
imageio.imwrite(os.path.join(out, f"{name}.tif"), pred)
34+
coords.to_csv(os.path.join(out, f"{name}.csv"), index=False)

scripts/la-vision/export_model.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import argparse
2+
import sys
3+
4+
import torch
5+
from torch_em.util import load_model
6+
7+
sys.path.append("/home/pape/Work/my_projects/czii-protein-challenge")
8+
sys.path.append("/user/pape41/u12086/Work/my_projects/czii-protein-challenge")
9+
sys.path.append("../synapse_marker_detection")
10+
11+
12+
def export_model(input_, output):
13+
model = load_model(input_, device="cpu")
14+
torch.save(model, output)
15+
16+
17+
def main():
18+
parser = argparse.ArgumentParser()
19+
parser.add_argument("-i", "--input", required=True)
20+
parser.add_argument("-o", "--output", required=True)
21+
args = parser.parse_args()
22+
export_model(args.input, args.output)
23+
24+
25+
if __name__ == "__main__":
26+
main()

0 commit comments

Comments
 (0)