Skip to content

Commit 81db81e

Browse files
Update synapse prediction script
1 parent df08168 commit 81db81e

File tree

2 files changed

+56
-14
lines changed

2 files changed

+56
-14
lines changed
Lines changed: 51 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
1+
import os
2+
from glob import glob
3+
14
import h5py
5+
import imageio.v3 as imageio
26
import napari
37
import zarr
48

59
from torch_em.util import load_model
610
from torch_em.util.prediction import predict_with_halo
711
from train_synapse_detection import get_paths
12+
from tqdm import tqdm
13+
14+
OUTPUT_ROOT = "./predictions"
815

916

1017
def run_prediction(val_image):
@@ -15,22 +22,56 @@ def run_prediction(val_image):
1522
return pred.squeeze()
1623

1724

18-
def main():
19-
val_paths, _ = get_paths("val")
20-
val_image = zarr.open(val_paths[0])["raw"][:]
21-
22-
# pred = run_prediction(val_image)
23-
# with h5py.File("pred.h5", "a") as f:
24-
# f.create_dataset("pred", data=pred, compression="gzip")
25+
def require_prediction(image_data, output_path):
26+
key = "prediction"
27+
if os.path.exists(output_path):
28+
with h5py.File(output_path, "r") as f:
29+
pred = f[key][:]
30+
else:
31+
pred = run_prediction(image_data)
32+
with h5py.File(output_path, "w") as f:
33+
f.create_dataset(key, data=pred, compression="gzip")
34+
return pred
2535

26-
with h5py.File("pred.h5", "r") as f:
27-
pred = f["pred"][:]
2836

37+
def visualize_results(image_data, pred):
2938
v = napari.Viewer()
30-
v.add_image(val_image)
39+
v.add_image(image_data)
3140
v.add_image(pred)
3241
napari.run()
3342

3443

44+
def check_val_image():
45+
val_paths, _ = get_paths("val")
46+
val_path = val_paths[0]
47+
val_image = zarr.open(val_path)["raw"][:]
48+
49+
os.makedirs(os.path.join(OUTPUT_ROOT, "val"), exist_ok=True)
50+
output_path = os.path.join(OUTPUT_ROOT, "val", os.path.basename(val_path).replace(".zarr", ".h5"))
51+
pred = require_prediction(val_image, output_path)
52+
53+
visualize_results(val_image, pred)
54+
55+
56+
def check_new_images():
57+
input_root = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/synapses/test_crops"
58+
inputs = glob(os.path.join(input_root, "*.tif"))
59+
output_folder = os.path.join(OUTPUT_ROOT, "new_crops")
60+
os.makedirs(output_folder, exist_ok=True)
61+
for path in tqdm(inputs):
62+
name = os.path.basename(path)
63+
if name == "M_AMD_58L_avgblendfused_RibB.tif":
64+
continue
65+
image_data = imageio.imread(path)
66+
output_path = os.path.join(output_folder, name.replace(".tif", ".h5"))
67+
require_prediction(image_data, output_path)
68+
69+
70+
# TODO update to support post-processing and showing annotations for the val data
71+
def main():
72+
# check_val_image()
73+
check_new_images()
74+
75+
3576
if __name__ == "__main__":
3677
main()

scripts/synapse_marker_detection/train_synapse_detection.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33

44
from detection_dataset import DetectionDataset
55

6-
sys.path.append("/home/pape/Work/my_projects/czii-protein-challenge")
7-
# sys.path.append("/user/pape41/u12086/Work/my_projects/czii-protein-challenge")
6+
# sys.path.append("/home/pape/Work/my_projects/czii-protein-challenge")
7+
sys.path.append("/user/pape41/u12086/Work/my_projects/czii-protein-challenge")
88

99
from utils.training.training import supervised_training # noqa
1010

11-
TRAIN_ROOT = "./training_data/images"
12-
LABEL_ROOT = "./training_data/labels"
11+
ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/synapses/training_data/v1" # noqa
12+
TRAIN_ROOT = os.path.join(ROOT, "images")
13+
LABEL_ROOT = os.path.join(ROOT, "labels")
1314

1415

1516
def get_paths(split):

0 commit comments

Comments
 (0)