Skip to content

Commit a85f3b8

Browse files
committed
scripts and stuff i forgot in the last PR
1 parent 068a90b commit a85f3b8

File tree

6 files changed

+347
-37
lines changed

6 files changed

+347
-37
lines changed
Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
import numpy as np
2+
3+
from synapse_net.file_utils import read_mrc
4+
import mrcfile
5+
import h5py
6+
import tifffile
7+
from pathlib import Path
8+
import re
9+
10+
def apply_ignore_label(h5_path, mask_path, ignore_label: int=-1):
11+
"""For supervised training: set masked voxels to -1 (ignore_label)."""
12+
with h5py.File(h5_path, "r") as f:
13+
raw = f["raw"][:]
14+
labels = f["labels/actin"][:]
15+
16+
with mrcfile.open(mask_path, permissive=True) as mrc:
17+
ignore_mask = mrc.data.astype(bool)
18+
#ignore_mask = np.flip(ignore_mask, axis=1)
19+
20+
labels_masked = labels.astype(np.int32) # ensure signed int type
21+
labels_masked[(labels == 0) & ignore_mask] = ignore_label
22+
23+
out_dir = Path(h5_path).parent / "ignore_label"
24+
out_dir.mkdir(parents=True, exist_ok=True)
25+
fstem = Path(h5_path).stem
26+
out_path = out_dir / f"{fstem}.h5"
27+
28+
print(f"Writing out h5 file with masked labels to {out_path}.")
29+
with h5py.File(out_path, "w") as f:
30+
f.create_dataset("raw", data=raw, compression="gzip")
31+
f.create_dataset("/labels/actin", data=labels_masked, compression="gzip")
32+
33+
def convert_tiff2mrc(in_dir, pixel_size, out_dir=None):
34+
"""Batch convert tiff files to mrc."""
35+
in_dir = Path(in_dir)
36+
37+
if out_dir == None:
38+
out_dir = in_dir
39+
else:
40+
out_dir = Path(out_dir)
41+
out_dir.mkdir(parents=True, exist_ok=True)
42+
43+
path_list = [str(p) for p in in_dir.glob("*.tif")]
44+
45+
for path in path_list:
46+
data = tifffile.imread(path)
47+
data = np.flip(data, axis=1)
48+
filename = Path(path).stem
49+
out_path = out_dir / f"{filename}.mrc"
50+
51+
print(f"Writing out mrc file to {out_path}.")
52+
with mrcfile.new(out_path, overwrite=True) as mrc:
53+
mrc.set_data(data.astype(np.uint8))
54+
mrc.voxel_size = (pixel_size, pixel_size, pixel_size)
55+
56+
def h5_split_tomograms(h5_path, z_range):
57+
"""
58+
Split paired raw and label data (z,y,x) into 8 non-overlapping subvolumes
59+
by cutting it in half along each axis.
60+
"""
61+
with h5py.File(h5_path, "r") as f:
62+
z0, z1 = z_range
63+
raw = f["raw"][z0:z1, :, :]
64+
labels = f["labels/actin"][z0:z1, :, :]
65+
66+
z, y, x = raw.shape
67+
68+
# Compute midpoints
69+
z_mid, y_mid, x_mid = z // 2, y // 2, x // 2
70+
71+
# Define ranges for each half
72+
z_ranges = [(0, z_mid), (z_mid, z)]
73+
y_ranges = [(0, y_mid), (y_mid, y)]
74+
x_ranges = [(0, x_mid), (x_mid, x)]
75+
76+
raw_subvols, label_subvols = [], []
77+
78+
for zi, (z0, z1) in enumerate(z_ranges):
79+
for yi, (y0, y1) in enumerate(y_ranges):
80+
for xi, (x0, x1) in enumerate(x_ranges):
81+
raw_subvol = raw[z0:z1, y0:y1, x0:x1]
82+
label_subvol = labels[z0:z1, y0:y1, x0:x1]
83+
raw_subvols.append(raw_subvol)
84+
label_subvols.append(label_subvol)
85+
86+
return raw_subvols, label_subvols
87+
88+
def write_h5(raw_path, label_path, out_path):
89+
"""Write the raw and labels to an HDF5 file."""
90+
if out_path.exists():
91+
print(f"File {out_path} already exists, skipping.")
92+
return
93+
94+
raw = read_mrc(raw_path)[0]
95+
labels = read_mrc(label_path)[0]
96+
97+
print(f"Writing file to {out_path}.")
98+
with h5py.File(out_path, "w") as f:
99+
f.create_dataset("raw", data=raw, compression="gzip")
100+
f.create_dataset("/labels/actin", data=labels, compression="gzip")
101+
102+
def write_h5_deepict():
103+
PARENT_DIR = Path("/mnt/data1/sage/actin-segmentation/data/deepict/deepict_actin/ignore_label")
104+
TRAIN_DIR = PARENT_DIR / "train"
105+
VAL_DIR = PARENT_DIR / "val"
106+
TEST_DIR = PARENT_DIR / "test"
107+
108+
TRAIN_DIR.mkdir(exist_ok=True)
109+
VAL_DIR.mkdir(exist_ok=True)
110+
TEST_DIR.mkdir(exist_ok=True)
111+
112+
raw_subvols1, label_subvols1 = h5_split_tomograms(
113+
Path(PARENT_DIR / "00004_cleaned.h5"), z_range = (326, 464)
114+
)
115+
raw_subvols2, label_subvols2 = h5_split_tomograms(
116+
Path(PARENT_DIR / "00012_cleaned.h5"), z_range = (147, 349)
117+
)
118+
119+
raw_subvols = raw_subvols1 + raw_subvols2
120+
label_subvols = label_subvols1 + label_subvols2
121+
122+
# predefined indices for train, val, test (10:2:4)
123+
train_idx = [0, 3, 4, 7, 8, 9, 10, 11, 12, 15]
124+
val_idx = [6, 14]
125+
test_idx = [1, 2, 5, 13]
126+
127+
def write_split(idx_list, folder, prefix):
128+
for idx in idx_list:
129+
raw = raw_subvols[idx]
130+
labels = label_subvols[idx]
131+
132+
# tomogram 00004: indices 0-7 -> A
133+
# tomogram 00012: indices 8-15 -> B
134+
if idx < 8:
135+
tag = f"A{idx}"
136+
else:
137+
tag = f"B{idx - 8}"
138+
out_path = folder / f"{prefix}_{tag}.h5"
139+
140+
print(f"Writing file to {out_path}.")
141+
with h5py.File(out_path, "w") as f:
142+
f.create_dataset("raw", data=raw, compression="gzip")
143+
f.create_dataset("/labels/actin", data=labels, compression="gzip")
144+
145+
write_split(train_idx, TRAIN_DIR, "train")
146+
write_split(val_idx, VAL_DIR, "val")
147+
write_split(test_idx, TEST_DIR, "test")
148+
print("\n Finished writing all subvolumes.")
149+
150+
def write_h5_optogenetics():
151+
RAW_DIR = Path("/mnt/data1/sage/actin-segmentation/data/EMPIAR-12292/tomos/")
152+
LABEL_DIR = Path("/mnt/data1/sage/actin-segmentation/data/EMPIAR-12292/labels/")
153+
OUT_DIR = Path("/mnt/data1/sage/actin-segmentation/data/EMPIAR-12292/h5/")
154+
155+
raw_paths = {re.sub('_rec', '', f.stem): f for f in RAW_DIR.glob("*_rec.mrc")}
156+
label_paths = {re.sub('_mask', '', f.stem): f for f in LABEL_DIR.glob("*_mask.mrc")}
157+
158+
stems = raw_paths.keys() | label_paths.keys()
159+
160+
for stem in stems:
161+
if stem not in raw_paths:
162+
print(f"Warning: Missing tomo file for {stem}.")
163+
continue
164+
165+
if stem not in label_paths:
166+
print(f"Warning: Missing label file for {stem}.")
167+
continue
168+
169+
raw_path = raw_paths[stem]
170+
label_path = label_paths[stem]
171+
out_path = OUT_DIR / f"{stem}.h5"
172+
write_h5(raw_path, label_path, out_path)
173+
174+
def main():
175+
#write_h5_optogenetics()
176+
#write_h5_deepict()
177+
#convert_tiff2mrc(
178+
# input_dir = "/mnt/data1/sage/actin-segmentation/data/deepict/deepict_actin/background_masks",
179+
# pixel_size = 13.48
180+
#)
181+
182+
183+
# apply ignore label for masking background during supervised training
184+
PARENT_DIR = Path("/mnt/data1/sage/actin-segmentation/data/deepict/deepict_actin/")
185+
MASK_DIR = PARENT_DIR / "background_masks"
186+
h5_paths = [PARENT_DIR / "00004_cleaned.h5", PARENT_DIR / "00012_cleaned.h5"]
187+
mask_paths = [MASK_DIR / "00004.mrc", MASK_DIR / "00012.mrc"]
188+
189+
for i, (path1, path2) in enumerate(zip(h5_paths, mask_paths)):
190+
apply_ignore_label(path1, path2)
191+
192+
write_h5_deepict()
193+
194+
if __name__ == "__main__":
195+
main()

scripts/cryo/actin/train_actin_segmentation.py

Lines changed: 74 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,101 @@
11
import numpy as np
22

3+
import torch_em
34
from synapse_net.training import supervised_training
45
from torch_em.data.sampler import MinForegroundSampler
56

7+
import os
8+
from pathlib import Path
69

7-
def train_actin_deepict():
10+
def train_actin_deepict_v3():
811
"""Train a network for actin segmentation on the deepict dataset.
12+
Tomograms are split into subvolumes, which are assigned to train, val, or test sets.
913
"""
14+
PARENT_DIR = Path("/mnt/data1/sage/actin-segmentation/data/deepict/deepict_actin/")
15+
TRAIN_DIR = PARENT_DIR / "train"
16+
VAL_DIR = PARENT_DIR / "val"
1017

11-
train_paths = [
12-
"/mnt/lustre-grete/usr/u12086/data/deepict/deepict_actin/00004.h5",
13-
"/mnt/lustre-grete/usr/u12086/data/deepict/deepict_actin/00012.h5",
14-
]
15-
val_paths = [
16-
"/mnt/lustre-grete/usr/u12086/data/deepict/deepict_actin/00012.h5",
17-
]
18+
train_paths = [str(p) for p in TRAIN_DIR.glob("*.h5")]
19+
val_paths = [str(p) for p in VAL_DIR.glob("*.h5")]
1820

19-
train_rois = [np.s_[:, :, :], np.s_[:250, :, :]]
20-
val_rois = [np.s_[250:, :, :]]
21+
patch_shape = (64, 384, 384)
22+
sampler = MinForegroundSampler(min_fraction=0.025, p_reject=0.95)
23+
24+
supervised_training(
25+
name="actin-deepict-v3",
26+
label_key="/labels/actin",
27+
patch_shape=patch_shape,
28+
train_paths=train_paths,
29+
val_paths=val_paths,
30+
n_iterations=int(25000),
31+
sampler=sampler,
32+
out_channels=2,
33+
add_boundary_transform=True,
34+
save_root="./output/experiment1/run1",
35+
check=False,
36+
device=0
37+
)
38+
39+
def train_actin_deepict_v4():
40+
"""Train a network for actin segmentation on the deepict dataset.
41+
Same as v3, with ignore_label to mask background voxels from loss.
42+
"""
43+
PARENT_DIR = Path("/mnt/data1/sage/actin-segmentation/data/deepict/deepict_actin/")
44+
TRAIN_DIR = PARENT_DIR / "train"
45+
VAL_DIR = PARENT_DIR / "val"
46+
47+
train_paths = [str(p) for p in TRAIN_DIR.glob("*.h5")]
48+
val_paths = [str(p) for p in VAL_DIR.glob("*.h5")]
2149

2250
patch_shape = (64, 384, 384)
2351
sampler = MinForegroundSampler(min_fraction=0.025, p_reject=0.95)
2452

2553
supervised_training(
26-
name="actin-deepict",
54+
name="actin-deepict-v4",
2755
label_key="/labels/actin",
2856
patch_shape=patch_shape,
2957
train_paths=train_paths,
3058
val_paths=val_paths,
31-
train_rois=train_rois,
32-
val_rois=val_rois,
59+
n_iterations=int(25000),
3360
sampler=sampler,
34-
save_root=".",
61+
out_channels=2,
62+
add_boundary_transform=True,
63+
save_root="./output/experiment1/run3",
64+
ignore_label=-1,
65+
check=False,
66+
device=4
3567
)
3668

69+
def train_actin_optogenetics():
70+
"""Train a network for actin segmentation on the EMPIAR-12292 dataset.
71+
"""
72+
PARENT_DIR = Path("/mnt/data1/sage/actin-segmentation/data/EMPIAR-12292/h5")
73+
74+
all_paths = [str(p) for p in PARENT_DIR.glob("*.h5")]
75+
train_paths = all_paths[:10]
76+
val_paths = all_paths[10:12]
77+
78+
patch_shape = (64, 384, 384)
79+
sampler = MinForegroundSampler(min_fraction=0.025, p_reject=0.95)
80+
81+
supervised_training(
82+
name="actin-opto-v1",
83+
label_key="/labels/actin",
84+
patch_shape=patch_shape,
85+
train_paths=train_paths,
86+
val_paths=val_paths,
87+
n_iterations=int(25000),
88+
sampler=sampler,
89+
out_channels=2,
90+
add_boundary_transform=True,
91+
save_root="./output/experiment2/run1",
92+
check=False,
93+
device=1
94+
)
3795

3896
def main():
39-
train_actin_deepict()
97+
train_actin_deepict_v4()
98+
#train_actin_optogenetics()
4099

41100

42101
if __name__ == "__main__":

0 commit comments

Comments
 (0)