Skip to content

Commit 6da20c5

Browse files
Add script to check synapse prediction
1 parent 9db6a83 commit 6da20c5

File tree

1 file changed

+25
-0
lines changed

1 file changed

+25
-0
lines changed
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import h5py
2+
import zarr
3+
from torch_em.util import load_model
4+
from torch_em.util.prediction import predict_with_halo
5+
from train_synapse_detection import get_paths
6+
7+
8+
def run_prediction(val_image):
9+
model = load_model("./checkpoints/synapse_detection_v1")
10+
block_shape = (32, 384, 384)
11+
halo = (8, 64, 64)
12+
pred = predict_with_halo(val_image, model, [0], block_shape, halo)
13+
return pred.squeeze()
14+
15+
16+
def main():
17+
val_paths, _ = get_paths("val")
18+
val_image = zarr.open(val_paths[0])["raw"][:]
19+
pred = run_prediction(val_image)
20+
with h5py.File("pred.h5", "a") as f:
21+
f.create_dataset("pred", data=pred, compression="gzip")
22+
23+
24+
if __name__ == "__main__":
25+
main()

0 commit comments

Comments
 (0)