
Commit f8d801f

Clean up AZ scripts and summarize current state
1 parent d7f71a0 commit f8d801f

File tree: 16 files changed (+431 / -128 lines)

scripts/cooper/revision/README.md

Lines changed: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
# Improving the AZ model

Scripts for improving the AZ annotations, training the AZ model, and evaluating it.

The most important scripts are:
- For improving and updating the AZ annotations:
  - `prediction.py`: Run prediction with the vesicle and boundary models.
  - `thin_az_gt.py`: Thin the AZ annotations so that they align only with the presynaptic membrane. This is done by intersecting the annotations with the presynaptic compartment, using predictions from the network used for compartment segmentation (see the sketch after this file).
  - `assort_new_az_data.py`: Create a new version of the annotations, renaming the datasets and creating a cropped version of the endbulb of held data.
  - `merge_az.py`: Merge the AZ annotations with predictions from model v4, in order to remove some artifacts that resulted from AZ thinning.
- For evaluating the AZ predictions:
  - `az_prediction.py`: Run prediction with the AZ model.
  - `run_az_evaluation.py`: Evaluate the predictions of an AZ model.
  - `evaluate_result.py`: Summarize the evaluation results.
- And for training: `train_az_gt.py`. So far, I have trained:
  - v3: Trained on the initial annotations.
  - v4: Trained on the thinned annotations.
  - v5: Trained on the thinned annotations with an additional distance loss (did not help).
  - v6: Trained on the merged annotations.
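The thinning step described for `thin_az_gt.py` is, at its core, a mask intersection. A minimal sketch of that idea, assuming the AZ annotation and the compartment prediction are volumes of the same shape (the function and names here are illustrative, not the actual `thin_az_gt.py` implementation):

```python
import numpy as np


def intersect_az_with_compartment(az: np.ndarray, presyn: np.ndarray) -> np.ndarray:
    """Keep only the AZ voxels that fall inside the predicted presynaptic compartment."""
    assert az.shape == presyn.shape
    return np.logical_and(az > 0, presyn > 0).astype(az.dtype)
```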
scripts/cooper/revision/az_prediction.py

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
import argparse
import os

import h5py
from synapse_net.inference.active_zone import segment_active_zone
from torch_em.util import load_model
from tqdm import tqdm

from common import get_file_names, get_split_folder, ALL_NAMES, INPUT_ROOT, OUTPUT_ROOT


def run_prediction(model, name, split_folder, version, split_names):
    file_names = get_file_names(name, split_folder, split_names=split_names)

    output_folder = os.path.join(OUTPUT_ROOT, name)
    os.makedirs(output_folder, exist_ok=True)
    output_key = f"predictions/az/v{version}"

    for fname in tqdm(file_names):
        output_path = os.path.join(output_folder, fname)

        # Skip tomograms that already have a prediction for this model version.
        if os.path.exists(output_path):
            with h5py.File(output_path, "r") as f:
                if output_key in f:
                    continue

        input_path = os.path.join(INPUT_ROOT, name, fname)
        with h5py.File(input_path, "r") as f:
            raw = f["raw"][:]

        _, pred = segment_active_zone(raw, model=model, verbose=False, return_predictions=True)
        with h5py.File(output_path, "a") as f:
            f.create_dataset(output_key, data=pred, compression="lzf")


def get_model(version):
    assert version in (3, 4, 5)
    split_folder = get_split_folder(version)
    if version == 3:
        model_path = os.path.join(split_folder, "checkpoints", "3D-AZ-model-TEM_STEM_ChemFix_wichmann-v3")
    else:
        model_path = os.path.join(split_folder, "checkpoints", f"v{version}")
    model = load_model(model_path)
    return model


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--version", "-v", type=int)
    parser.add_argument("--names", nargs="+", default=ALL_NAMES)
    parser.add_argument("--splits", nargs="+", default=["test"])
    args = parser.parse_args()

    model = get_model(args.version)
    split_folder = get_split_folder(args.version)
    for name in args.names:
        run_prediction(model, name, split_folder, args.version, args.splits)


if __name__ == "__main__":
    main()
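After a run, each output file accumulates one dataset per model version under `predictions/az/v{version}`. A minimal sketch for loading a stored prediction back, assuming a v4 run on the `stem` dataset (the tomogram file name below is a placeholder, not from the commit):

```python
import os

import h5py

from common import OUTPUT_ROOT

# Placeholder file name; use one of the names from the corresponding split file.
output_path = os.path.join(OUTPUT_ROOT, "stem", "example_tomogram.h5")
with h5py.File(output_path, "r") as f:
    pred = f["predictions/az/v4"][:]
print(pred.shape)
```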
Lines changed: 44 additions & 0 deletions
@@ -0,0 +1,44 @@
import argparse
import os

import h5py
import napari

from common import ALL_NAMES, get_file_names, get_split_folder, get_paths


def check_predictions(name, split, version):
    split_folder = get_split_folder(version)
    file_names = get_file_names(name, split_folder, split_names=[split])
    seg_paths, gt_paths = get_paths(name, file_names)

    for seg_path, gt_path in zip(seg_paths, gt_paths):

        with h5py.File(gt_path, "r") as f:
            raw = f["raw"][:]
            gt = f["labels/az"][:] if version == 3 else f["labels/az_thin"][:]

        with h5py.File(seg_path, "r") as f:
            seg_key = f"predictions/az/v{version}"
            pred = f[seg_key][:]

        # Show raw data, prediction, and ground-truth together in napari.
        v = napari.Viewer()
        v.add_image(raw)
        v.add_image(pred, blending="additive")
        v.add_labels(gt)
        v.title = f"{name}/{os.path.basename(seg_path)}"
        napari.run()


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--version", "-v", type=int, required=True)
    parser.add_argument("--split", default="test")
    parser.add_argument("--names", nargs="+", default=ALL_NAMES)
    args = parser.parse_args()

    for name in args.names:
        check_predictions(name, args.split, args.version)


if __name__ == "__main__":
    main()

scripts/cooper/revision/common.py

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
import json
import os


# The root folder which contains the new AZ training data.
INPUT_ROOT = "/mnt/ceph-hdd/cold_store/projects/nim00007/new_AZ_train_data"
# The output folder for AZ predictions.
OUTPUT_ROOT = "/mnt/ceph-hdd/cold_store/projects/nim00007/AZ_predictions_new"

# The names of all datasets for which to run prediction / evaluation.
# This excludes 'endbulb_of_held_cropped', which is a duplicate of
# 'endbulb_of_held' and is therefore not evaluated.
ALL_NAMES = [
    "chemical_fixation", "endbulb_of_held", "stem", "stem_cropped", "tem"
]

# The translation of new dataset names to old dataset names.
NAME_TRANSLATION = {
    "chemical_fixation": ["12_chemical_fix_cryopreparation_minusSVseg_corrected"],
    "endbulb_of_held": ["wichmann_withAZ_rescaled_tomograms"],
    "stem": ["04_hoi_stem_examples_fidi_and_sarah_corrected_rescaled_tomograms"],
    "stem_cropped": ["04_hoi_stem_examples_minusSVseg_cropped_corrected_rescaled_tomograms",
                     "06_hoi_wt_stem750_fm_minusSVseg_cropped_corrected_rescaled_tomograms"],
    "tem": ["01data_withoutInvertedFiles_minusSVseg_corrected"],
}


def get_paths(name, file_names, skip_seg=False):
    """Get the paths to the files with raw data / ground-truth and the segmentation."""
    seg_paths, gt_paths = [], []
    for fname in file_names:
        if not skip_seg:
            seg_path = os.path.join(OUTPUT_ROOT, name, fname)
            assert os.path.exists(seg_path), seg_path
            seg_paths.append(seg_path)

        gt_path = os.path.join(INPUT_ROOT, name, fname)
        assert os.path.exists(gt_path), gt_path
        gt_paths.append(gt_path)

    return seg_paths, gt_paths


def get_file_names(name, split_folder, split_names):
    """Get the file names in the given splits, falling back to the old dataset names."""
    split_path = os.path.join(split_folder, f"split-{name}.json")
    if os.path.exists(split_path):
        with open(split_path) as f:
            splits = json.load(f)
        file_names = [fname for split in split_names for fname in splits[split]]

    else:
        old_names = NAME_TRANSLATION[name]
        file_names = []
        for old_name in old_names:
            split_path = os.path.join(split_folder, f"split-{old_name}.json")
            with open(split_path) as f:
                splits = json.load(f)
            this_file_names = [fname for split in split_names for fname in splits[split]]
            file_names.extend(this_file_names)
    return file_names


def get_split_folder(version):
    assert version in (3, 4, 5)
    if version == 3:
        split_folder = "splits"
    else:
        split_folder = "models_az_thin"
    return split_folder
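The split files themselves are not part of this commit, but `get_file_names` implies their structure: one JSON file per dataset, mapping split names to lists of file names. A hypothetical example, shown as the equivalent Python dict (the file names are invented):

```python
# Contents of e.g. "models_az_thin/split-stem.json", as implied by get_file_names.
# The file names below are invented for illustration.
example_split = {
    "train": ["tomo_01.h5", "tomo_02.h5"],
    "val": ["tomo_03.h5"],
    "test": ["tomo_04.h5"],
}
```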

scripts/cooper/revision/eval_AZ.sh

Lines changed: 0 additions & 8 deletions
This file was deleted.

scripts/cooper/revision/evaluate_result.py

Lines changed: 7 additions & 6 deletions
@@ -6,6 +6,7 @@
 args = parser.parse_args()
 
 results = pd.read_excel(args.result_path)
+print(results)
 
 
 def summarize_results(res):
@@ -20,29 +21,29 @@ def summarize_results(res):
 
 
 # # Compute the results for Chemical Fixation.
-results_chem_fix = results[results.dataset.str.startswith("12")]
+results_chem_fix = results[results.dataset == "chemical_fixation"]
 if results_chem_fix.size > 0:
     print("Chemical Fixation Results:")
     summarize_results(results_chem_fix)
 #
 # # Compute the results for STEM (=04).
-results_stem = results[results.dataset.str.startswith(("04", "06"))]
+results_stem = results[results.dataset.str.startswith("stem")]
 if results_stem.size > 0:
     print()
     print("STEM Results:")
     summarize_results(results_stem)
 #
 # # Compute the results for TEM (=01).
-results_tem = results[results.dataset.str.startswith("01")]
+results_tem = results[results.dataset == "tem"]
 if results_tem.size > 0:
     print()
     print("TEM Results:")
     summarize_results(results_tem)
 
 #
-# Compute the results for Wichmann.
-results_wichmann = results[results.dataset.str.startswith("wichmann")]
+# Compute the results for Wichmann / endbulb of held.
+results_wichmann = results[results.dataset.str.startswith("endbulb")]
 if results_wichmann.size > 0:
     print()
-    print("Wichmann Results:")
+    print("Endbulb of Held Results:")
     summarize_results(results_wichmann)
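The body of `summarize_results` is outside this hunk. Consistent with how it is called here, a minimal hypothetical stand-in would take a filtered sub-table of the results DataFrame and print aggregate metrics (the metric columns are assumptions, not shown in the diff):

```python
import pandas as pd


def summarize_results(res: pd.DataFrame) -> None:
    # Hypothetical stand-in: average all numeric metric columns of the
    # filtered results table; the real metric names are not in this diff.
    print(res.mean(numeric_only=True))
```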

scripts/cooper/revision/fix_az.py

Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
import os
from glob import glob

import h5py
from tqdm import tqdm


INPUT_ROOT = "/mnt/ceph-hdd/cold_store/projects/nim00007/new_AZ_train_data"

files = glob(os.path.join(INPUT_ROOT, "**/*.h5"), recursive=True)

# Remove singleton dimensions from the merged AZ annotations by rewriting the dataset.
key = "labels/az_merged"
for ff in tqdm(files):
    with h5py.File(ff, "a") as f:
        az = f[key][:]
        az = az.squeeze()
        del f[key]
        f.create_dataset(key, data=az, compression="lzf")
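Since this script rewrites datasets in place, a read-only dry run can first confirm which files actually carry a singleton axis. A minimal sketch of such a check (not part of the commit):

```python
import os
from glob import glob

import h5py

INPUT_ROOT = "/mnt/ceph-hdd/cold_store/projects/nim00007/new_AZ_train_data"

# Report files whose merged AZ labels contain a singleton dimension, without writing.
for ff in glob(os.path.join(INPUT_ROOT, "**/*.h5"), recursive=True):
    with h5py.File(ff, "r") as f:
        shape = f["labels/az_merged"].shape
    if 1 in shape:
        print(f"{ff}: shape {shape} has a singleton axis")
```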

scripts/cooper/revision/generate_az_eval_data.py

Lines changed: 0 additions & 31 deletions
This file was deleted.
