Commit 48618d0

added compartment seg saved in h5
1 parent f17c349 commit 48618d0

File tree: 2 files changed, +109 -5 lines

run_sbatch_revision.sbatch

Lines changed: 4 additions & 5 deletions
@@ -2,15 +2,14 @@
 #SBATCH -c 4 #4 #8
 #SBATCH --mem 256G #120G #32G #64G #256G
 #SBATCH -p grete:shared #grete:shared #grete-h100:shared
-#SBATCH -t 6:00:00 #6:00:00 #48:00:00
+#SBATCH -t 3:00:00 #6:00:00 #48:00:00
 #SBATCH -G A100:1 #V100:1 #2 #A100:1 #gtx1080:2 #v100:1 #H100:1
 #SBATCH --output=/user/muth9/u12095/synapse-net/slurm_revision/slurm-%j.out
 #SBATCH -A nim00007
 #SBATCH --constraint 80gb

 source ~/.bashrc
 conda activate synapse-net
-python /user/muth9/u12095/synapse-net/scripts/cooper/revision/updated_data_analysis/run_data_analysis.py \
-    -i /mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/20241102_TOMO_DATA_Imig2014/exported/SNAP25/ \
-    -o /mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/cooper/20241102_TOMO_DATA_Imig2014/afterRevision_analysis/boundaryT0_9_constantins_presynapticFiltering --store \
-    -s ./analysis_results/man_subset
+python /user/muth9/u12095/synapse-net/scripts/cooper/run_compartment_segmentation_h5.py \
+    -i /mnt/ceph-hdd/cold/nim00007/new_AZ_train_data/stem_for_eval/20241019_Tomo-eval_PS_Synapse_36859_J1_66K_TS_CA3_PS_46_rec_2Kb1dawbp_crop.h5 \
+    --data_ext .h5
scripts/cooper/run_compartment_segmentation_h5.py

Lines changed: 105 additions & 0 deletions (new file)
import argparse

import h5py
import numpy as np
from elf.io import open_file

from synapse_net.inference.compartments import segment_compartments
from synapse_net.inference.inference import get_model_path
from synapse_net.inference.util import parse_tiling


def get_volume(input_path):
    """Load the raw tomogram volume from the input HDF5 file."""
    # Earlier approach, kept for reference:
    # with h5py.File(input_path) as seg_file:
    #     input_volume = seg_file["raw"][:]
    with open_file(input_path, "r") as f:
        # Try to automatically derive the key with the raw data.
        keys = list(f.keys())
        if len(keys) == 1:
            key = keys[0]
        elif "data" in keys:
            key = "data"
        elif "raw" in keys:
            key = "raw"
        else:
            raise ValueError(f"Could not derive the raw data key; found the keys: {keys}.")

        input_volume = f[key][:]
    return input_volume


def run_compartment_segmentation(args):
    tiling = parse_tiling(args.tile_shape, args.halo)

    if args.model is None:
        model_path = get_model_path("compartments")
    else:
        model_path = args.model

    # Call segment_compartments directly, since we need its outputs.
    segmentation, predictions = segment_compartments(
        get_volume(args.input_path),
        model_path=model_path,
        verbose=True,
        tiling=tiling,
        scale=None,
        boundary_threshold=args.boundary_threshold,
        return_predictions=True,
    )

    # Save the outputs into the input HDF5 file.
    with h5py.File(args.input_path, "a") as f:
        pred_grp = f.require_group("predictions")

        if "comp_seg" in pred_grp:
            if args.force:
                del pred_grp["comp_seg"]
            else:
                raise RuntimeError("comp_seg already exists. Use --force to overwrite.")
        pred_grp.create_dataset("comp_seg", data=segmentation.astype(np.uint8), compression="gzip")

        if "boundaries" in pred_grp:
            if args.force:
                del pred_grp["boundaries"]
            else:
                raise RuntimeError("boundaries already exist. Use --force to overwrite.")
        pred_grp.create_dataset("boundaries", data=predictions.astype(np.float32), compression="gzip")

    print("Saved segmentation to: predictions/comp_seg")
    print("Saved boundaries to: predictions/boundaries")


def main():
    parser = argparse.ArgumentParser(description="Segment synaptic compartments in EM tomograms.")
    parser.add_argument(
        "--input_path", "-i", required=True,
        help="The filepath to the .h5 file containing the tomogram data."
    )
    parser.add_argument(
        "--model", "-m", help="The filepath to the compartment model."
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Whether to over-write already present segmentation results."
    )
    parser.add_argument(
        "--tile_shape", type=int, nargs=3,
        help="The tile shape for prediction. Lower the tile shape if GPU memory is insufficient."
    )
    parser.add_argument(
        "--halo", type=int, nargs=3,
        help="The halo for prediction. Increase the halo to minimize boundary artifacts."
    )
    parser.add_argument(
        "--data_ext", default=".mrc", help="The extension of the tomogram data. By default .mrc."
    )
    parser.add_argument(
        "--boundary_threshold", type=float, default=0.4,
        help="Threshold that determines when the prediction of the network is foreground for the segmentation. "
             "Needs a higher threshold than the default for TEM."
    )

    args = parser.parse_args()
    run_compartment_segmentation(args)


if __name__ == "__main__":
    main()
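
After the job finishes, the two datasets written by the script can be read back from the same file. A minimal sketch, assuming h5py is installed and using "tomogram.h5" as a placeholder for the file passed via -i:

import h5py

# "tomogram.h5" is a placeholder for the tomogram passed to the script via -i.
with h5py.File("tomogram.h5", "r") as f:
    seg = f["predictions/comp_seg"][:]           # uint8 compartment segmentation
    boundaries = f["predictions/boundaries"][:]  # float32 boundary predictions
    print(seg.shape, seg.dtype, boundaries.dtype)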
