
Commit d7f71a0

Update AZ training data
1 parent 39d69d5 commit d7f71a0

11 files changed: +819 lines, -31 lines

Lines changed: 210 additions & 0 deletions
@@ -0,0 +1,210 @@
import os
from glob import glob

import h5py
import numpy as np
from tqdm import tqdm
from skimage.transform import resize

ROOT = "/mnt/ceph-hdd/cold_store/projects/nim00007/AZ_data/training_data"
INTER_ROOT = "/mnt/ceph-hdd/cold_store/projects/nim00007/AZ_predictions"
OUTPUT_ROOT = "/mnt/ceph-hdd/cold_store/projects/nim00007/new_AZ_train_data"


def _check_data(files, label_folder, check_thinned):
    # Print the raw shape and the max label id of the AZ labels
    # (and, optionally, of the thinned AZ labels) for each file.
    for ff in files:
        with h5py.File(ff, "r") as f:
            shape = f["raw"].shape
            az = f["labels/az"][:]
            n_az = az.max()

        if check_thinned:
            label_file = os.path.join(label_folder, os.path.basename(ff))
            with h5py.File(label_file, "r") as f:
                az_thin = f["labels/az_thin2"][:]
                n_az_thin = az_thin.max()
        else:
            n_az_thin = None

        print(os.path.basename(ff), ":", shape, ":", n_az, ":", n_az_thin)


def assort_tem():
    old_name = "01data_withoutInvertedFiles_minusSVseg_corrected"
    new_name = "tem"

    raw_folder = os.path.join(ROOT, old_name)
    label_folder = os.path.join(INTER_ROOT, old_name)
    output_folder = os.path.join(OUTPUT_ROOT, new_name)
    os.makedirs(output_folder, exist_ok=True)

    files = glob(os.path.join(raw_folder, "*.h5"))
    for ff in tqdm(files):
        with h5py.File(ff, "r") as f:
            raw = f["raw"][:]
            az = f["labels/az"][:]

        label_path = os.path.join(label_folder, os.path.basename(ff))
        with h5py.File(label_path, "r") as f:
            az_thin = f["labels/az_thin2"][:]

        # Crop to the z-slab that contains AZ labels in either annotation.
        z_range1 = np.where(az != 0)[0]
        z_range2 = np.where(az_thin != 0)[0]
        z_range = slice(
            np.min(np.concatenate([z_range1, z_range2])),
            np.max(np.concatenate([z_range1, z_range2])) + 1,
        )
        raw, az, az_thin = raw[z_range], az[z_range], az_thin[z_range]

        out_path = os.path.join(output_folder, os.path.basename(ff))
        with h5py.File(out_path, "a") as f:
            f.create_dataset("raw", data=raw, compression="lzf")
            f.create_dataset("labels/az_thin", data=az_thin, compression="lzf")
            f.create_dataset("labels/az", data=az, compression="lzf")


def assort_chemical_fixation():
    old_name = "12_chemical_fix_cryopreparation_minusSVseg_corrected"
    new_name = "chemical_fixation"

    raw_folder = os.path.join(ROOT, old_name)
    label_folder = os.path.join(INTER_ROOT, old_name)
    output_folder = os.path.join(OUTPUT_ROOT, new_name)
    os.makedirs(output_folder, exist_ok=True)

    label_key = "labels/az_thin2"

    files = glob(os.path.join(raw_folder, "*.h5"))
    for ff in tqdm(files):
        with h5py.File(ff, "r") as f:
            raw = f["raw"][:]
            az = f["labels/az"][:]

        label_path = os.path.join(label_folder, os.path.basename(ff))
        with h5py.File(label_path, "r") as f:
            az_thin = f[label_key][:]

        # Crop to the z-slab that contains AZ labels in either annotation.
        z_range1 = np.where(az != 0)[0]
        z_range2 = np.where(az_thin != 0)[0]
        z_range = slice(
            np.min(np.concatenate([z_range1, z_range2])),
            np.max(np.concatenate([z_range1, z_range2])) + 1,
        )
        raw, az, az_thin = raw[z_range], az[z_range], az_thin[z_range]

        out_path = os.path.join(output_folder, os.path.basename(ff))
        with h5py.File(out_path, "a") as f:
            f.create_dataset("raw", data=raw, compression="lzf")
            f.create_dataset("labels/az_thin", data=az_thin, compression="lzf")
            f.create_dataset("labels/az", data=az, compression="lzf")


def assort_stem():
    old_names = [
        "04_hoi_stem_examples_fidi_and_sarah_corrected",
        "04_hoi_stem_examples_minusSVseg_cropped_corrected",
        "06_hoi_wt_stem750_fm_minusSVseg_cropped_corrected",
    ]
    new_names = ["stem", "stem_cropped", "stem_cropped"]
    for old_name, new_name in zip(old_names, new_names):
        print(old_name)
        raw_folder = os.path.join(ROOT, f"{old_name}_rescaled_tomograms")
        label_folder = os.path.join(INTER_ROOT, old_name)
        files = glob(os.path.join(raw_folder, "*.h5"))

        # _check_data(files, label_folder, check_thinned=True)
        # continue

        output_folder = os.path.join(OUTPUT_ROOT, new_name)
        os.makedirs(output_folder, exist_ok=True)
        for ff in tqdm(files):
            with h5py.File(ff, "r") as f:
                raw = f["raw"][:]
                az = f["labels/az"][:]

            label_path = os.path.join(label_folder, os.path.basename(ff))
            with h5py.File(label_path, "r") as f:
                az_thin = f["labels/az_thin2"][:]
            # The thinned labels were computed on the original tomograms; resize them
            # (nearest neighbor) to the shape of the rescaled labels.
            az_thin = resize(
                az_thin, az.shape, order=0, anti_aliasing=False, preserve_range=True
            ).astype(az_thin.dtype)
            assert az_thin.shape == az.shape

            out_path = os.path.join(output_folder, os.path.basename(ff))
            with h5py.File(out_path, "a") as f:
                f.create_dataset("raw", data=raw, compression="lzf")
                f.create_dataset("labels/az_thin", data=az_thin, compression="lzf")
                f.create_dataset("labels/az", data=az, compression="lzf")


def assort_wichmann():
    old_name = "wichmann_withAZ_rescaled_tomograms"
    new_name = "endbulb_of_held"

    raw_folder = os.path.join(ROOT, old_name)
    output_folder = os.path.join(OUTPUT_ROOT, new_name)
    os.makedirs(output_folder, exist_ok=True)

    files = glob(os.path.join(raw_folder, "*.h5"))
    for ff in tqdm(files):
        with h5py.File(ff, "r") as f:
            raw = f["raw"][:]
            az = f["labels/az"][:]

        output_file = os.path.join(output_folder, os.path.basename(ff))
        with h5py.File(output_file, "a") as f:
            f.create_dataset("raw", data=raw, compression="lzf")
            f.create_dataset("labels/az", data=az, compression="lzf")
            # No thinned annotations exist for this dataset, so the regular AZ labels
            # are stored under 'labels/az_thin' as well to keep the keys uniform.
            f.create_dataset("labels/az_thin", data=az, compression="lzf")


def crop_wichmann():
    input_name = "endbulb_of_held"
    output_name = "endbulb_of_held_cropped"

    input_folder = os.path.join(OUTPUT_ROOT, input_name)
    output_folder = os.path.join(OUTPUT_ROOT, output_name)
    os.makedirs(output_folder, exist_ok=True)
    files = glob(os.path.join(input_folder, "*.h5"))

    min_shape = (32, 512, 512)

    for ff in tqdm(files):
        with h5py.File(ff, "r") as f:
            az = f["labels/az"][:]
            # Bounding box of the AZ labels, padded so that each crop measures at
            # least min_shape and clipped to the volume boundaries.
            bb = np.where(az != 0)
            bb = tuple(slice(int(b.min()), int(b.max()) + 1) for b in bb)
            pad_width = [max(sh - (b.stop - b.start), 0) // 2 for b, sh in zip(bb, min_shape)]
            bb = tuple(
                slice(max(b.start - pw, 0), min(b.stop + pw, sh))
                for b, pw, sh in zip(bb, pad_width, az.shape)
            )
            az = az[bb]
            raw = f["raw"][bb]

        # import napari
        # v = napari.Viewer()
        # v.add_image(raw)
        # v.add_labels(az)
        # napari.run()

        output_path = os.path.join(output_folder, os.path.basename(ff).replace(".h5", "_cropped.h5"))
        with h5py.File(output_path, "a") as f:
            f.create_dataset("raw", data=raw, compression="lzf")
            f.create_dataset("labels/az", data=az, compression="lzf")
            f.create_dataset("labels/az_thin", data=az, compression="lzf")


def main():
    # assort_tem()
    # assort_chemical_fixation()

    # assort_stem()

    # assort_wichmann()
    crop_wichmann()


if __name__ == "__main__":
    main()
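The assort functions above all restrict raw data and labels to the z-slab that actually contains AZ annotations. A minimal, self-contained sketch of that cropping step on a toy volume (all names and shapes below are illustrative, not from the repository):

import numpy as np

# Toy volume: 10 z-slices, AZ labels only on slices 3-5.
az = np.zeros((10, 8, 8), dtype="uint8")
az[3:6, 2:6, 2:6] = 1
raw = np.random.rand(*az.shape).astype("float32")

# Keep only the z-slab that contains labels, as in assort_tem.
z_nonzero = np.where(az != 0)[0]
z_range = slice(int(z_nonzero.min()), int(z_nonzero.max()) + 1)
raw, az = raw[z_range], az[z_range]
assert raw.shape == az.shape == (3, 8, 8)  # slices 3, 4 and 5 remain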
Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
1+
import argparse
2+
import os
3+
from glob import glob
4+
5+
import napari
6+
import h5py
7+
8+
ROOT = "/mnt/ceph-hdd/cold_store/projects/nim00007/new_AZ_train_data"
9+
all_names = [
10+
"chemical_fixation",
11+
"tem",
12+
"stem",
13+
"stem_cropped",
14+
"endbulb_of_held",
15+
"endbulb_of_held_cropped",
16+
]
17+
18+
19+
parser = argparse.ArgumentParser()
20+
parser.add_argument("-n", "--names", nargs="+", default=all_names)
21+
args = parser.parse_args()
22+
names = args.names
23+
24+
25+
for ds in names:
26+
paths = glob(os.path.join(ROOT, ds, "*.h5"))
27+
for p in paths:
28+
with h5py.File(p, "r") as f:
29+
raw = f["raw"][:]
30+
az = f["labels/az"][:]
31+
az_thin = f["labels/az_thin"][:]
32+
v = napari.Viewer()
33+
v.add_image(raw)
34+
v.add_labels(az)
35+
v.add_labels(az_thin)
36+
v.title = os.path.basename(p)
37+
napari.run()
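A possible invocation of this viewer script (its file path is not shown in this diff view, so the script name below is a placeholder):

python check_az_training_data.py -n tem stem_cropped

Each tomogram opens in its own napari viewer with the raw data and both label variants; closing the viewer window advances to the next file.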

scripts/cooper/revision/eval_AZ.sh

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
python run_az_evaluation.py \
    -s /mnt/ceph-hdd/cold_store/projects/nim00007/AZ_data/segmentations \
    -g /mnt/ceph-hdd/cold_store/projects/nim00007/AZ_data/training_data \
    --seg_key /AZ/segment_from_AZmodel_TEM_STEM_ChemFix_v1 \
    --criterion iop \
    -o v1
# --dataset 01 \
# --seg_key AZ/segment_from_AZmodel_TEM_STEM_ChemFix_wichmann_v2 \
Lines changed: 48 additions & 0 deletions
@@ -0,0 +1,48 @@
import argparse

import pandas as pd

parser = argparse.ArgumentParser()
parser.add_argument("result_path")
args = parser.parse_args()

results = pd.read_excel(args.result_path)


def summarize_results(res):
    print("Dice-Score:", res["dice"].mean(), "+-", res["dice"].std())
    tp, fp, fn = float(res["tp"].sum()), float(res["fp"].sum()), float(res["fn"].sum())
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f1_score = 2 * tp / (2 * tp + fn + fp)
    print("Precision:", precision)
    print("Recall:", recall)
    print("F1-Score:", f1_score)


# Compute the results for Chemical Fixation (dataset 12).
results_chem_fix = results[results.dataset.str.startswith("12")]
if results_chem_fix.size > 0:
    print("Chemical Fixation Results:")
    summarize_results(results_chem_fix)

# Compute the results for STEM (datasets 04 and 06).
results_stem = results[results.dataset.str.startswith(("04", "06"))]
if results_stem.size > 0:
    print()
    print("STEM Results:")
    summarize_results(results_stem)

# Compute the results for TEM (dataset 01).
results_tem = results[results.dataset.str.startswith("01")]
if results_tem.size > 0:
    print()
    print("TEM Results:")
    summarize_results(results_tem)

# Compute the results for Wichmann.
results_wichmann = results[results.dataset.str.startswith("wichmann")]
if results_wichmann.size > 0:
    print()
    print("Wichmann Results:")
    summarize_results(results_wichmann)
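A quick sanity check of the aggregation above, with toy counts rather than real results:

# Toy counts: tp=8, fp=2, fn=4.
tp, fp, fn = 8.0, 2.0, 4.0
precision = tp / (tp + fp)        # 8/10 = 0.8
recall = tp / (tp + fn)           # 8/12 ≈ 0.667
f1 = 2 * tp / (2 * tp + fn + fp)  # 16/22 ≈ 0.727
# This F1 formula is equivalent to the harmonic mean of precision and recall.
assert abs(f1 - 2 * precision * recall / (precision + recall)) < 1e-12

Note that precision, recall, and F1 are computed from the tp/fp/fn counts summed over all rows of a dataset, so tomograms with more annotations weigh more, whereas the Dice score is averaged per file.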
Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
import os
from glob import glob

import h5py
from tqdm import tqdm

from synapse_net.inference.compartments import segment_compartments
from synapse_net.inference.inference import get_model, compute_scale_from_voxel_size
from synapse_net.inference.vesicles import segment_vesicles

ROOT = "/mnt/ceph-hdd/cold_store/projects/nim00007/AZ_data/training_data"
OUTPUT_ROOT = "/mnt/ceph-hdd/cold_store/projects/nim00007/AZ_predictions"

# Voxel size per dataset folder.
RESOLUTIONS = {
    "01data_withoutInvertedFiles_minusSVseg_corrected": {"x": 1.554, "y": 1.554, "z": 1.554},
    "04_hoi_stem_examples_fidi_and_sarah_corrected": {"x": 0.8681, "y": 0.8681, "z": 0.8681},
    "04_hoi_stem_examples_fidi_and_sarah_corrected_rescaled_tomograms": {"x": 1.554, "y": 1.554, "z": 1.554},
    "04_hoi_stem_examples_minusSVseg_cropped_corrected": {"x": 0.8681, "y": 0.8681, "z": 0.8681},
    "04_hoi_stem_examples_minusSVseg_cropped_corrected_rescaled_tomograms": {"x": 1.554, "y": 1.554, "z": 1.554},
    "06_hoi_wt_stem750_fm_minusSVseg_cropped_corrected": {"x": 0.8681, "y": 0.8681, "z": 0.8681},
    "06_hoi_wt_stem750_fm_minusSVseg_cropped_corrected_rescaled_tomograms": {"x": 1.554, "y": 1.554, "z": 1.554},
    "12_chemical_fix_cryopreparation_minusSVseg_corrected": {"x": 1.554, "y": 1.554, "z": 1.554},
    "wichmann_withAZ": {"x": 1.748, "y": 1.748, "z": 1.748},
    "wichmann_withAZ_rescaled_tomograms": {"x": 1.554, "y": 1.554, "z": 1.554},
}


def predict_boundaries(model, path, output_path):
    output_key = "predictions/boundaries"
    # Skip files for which the boundary predictions already exist.
    if os.path.exists(output_path):
        with h5py.File(output_path, "r") as f:
            if output_key in f:
                return

    dataset = os.path.basename(os.path.split(path)[0])

    with h5py.File(path, "r") as f:
        data = f["raw"][:]
    scale = compute_scale_from_voxel_size(RESOLUTIONS[dataset], "compartments")
    _, pred = segment_compartments(data, model=model, scale=scale, verbose=False, return_predictions=True)
    with h5py.File(output_path, "a") as f:
        f.create_dataset(output_key, data=pred, compression="lzf")


def predict_all_boundaries():
    model = get_model("compartments")
    files = sorted(glob(os.path.join(ROOT, "**/*.h5"), recursive=True))
    for path in tqdm(files):
        folder_name = os.path.basename(os.path.split(path)[0])
        output_folder = os.path.join(OUTPUT_ROOT, folder_name)
        os.makedirs(output_folder, exist_ok=True)
        output_path = os.path.join(output_folder, os.path.basename(path))
        predict_boundaries(model, path, output_path)


def predict_vesicles(model, path, output_path):
    output_key = "predictions/vesicle_seg"
    # Skip files for which the vesicle segmentation already exists.
    if os.path.exists(output_path):
        with h5py.File(output_path, "r") as f:
            if output_key in f:
                return

    # Skip the rescaled tomogram copies.
    dataset = os.path.basename(os.path.split(path)[0])
    if "rescaled" in dataset:
        return

    with h5py.File(path, "r") as f:
        data = f["raw"][:]
    scale = compute_scale_from_voxel_size(RESOLUTIONS[dataset], "vesicles_3d")
    seg = segment_vesicles(data, model=model, scale=scale, verbose=False)
    with h5py.File(output_path, "a") as f:
        f.create_dataset(output_key, data=seg, compression="lzf")


def predict_all_vesicles():
    model = get_model("vesicles_3d")
    files = sorted(glob(os.path.join(ROOT, "**/*.h5"), recursive=True))
    for path in tqdm(files):
        folder_name = os.path.basename(os.path.split(path)[0])
        output_folder = os.path.join(OUTPUT_ROOT, folder_name)
        os.makedirs(output_folder, exist_ok=True)
        output_path = os.path.join(output_folder, os.path.basename(path))
        predict_vesicles(model, path, output_path)


def main():
    # predict_all_boundaries()
    predict_all_vesicles()


if __name__ == "__main__":
    main()
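The RESOLUTIONS table feeds compute_scale_from_voxel_size, which determines how much each tomogram is rescaled before inference. A plausible reading of that step, assuming the scale factor is simply the ratio between the tomogram's voxel size and the voxel size the chosen model expects (the helper below and the reference values are illustrative assumptions, not the library's actual implementation):

# Hypothetical sketch: per-axis rescale factor from tomogram and model voxel sizes.
def scale_from_voxel_size(voxel_size, model_voxel_size):
    return [voxel_size[ax] / model_voxel_size[ax] for ax in ("z", "y", "x")]

# Under this assumption, a 1.748 tomogram fed to a model expecting 1.554 would be
# upsampled by roughly 1.12 per axis:
print(scale_from_voxel_size({"x": 1.748, "y": 1.748, "z": 1.748},
                            {"x": 1.554, "y": 1.554, "z": 1.554}))  # ~[1.125, 1.125, 1.125]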
