
Commit e3a46c9
committed: changes to semisupervised_training.py domain_adaptation.py
1 parent a85f3b8 commit e3a46c9

File tree: 4 files changed, +336 -53 lines changed

Lines changed: 60 additions & 32 deletions

@@ -1,54 +1,82 @@
 import os
 from glob import glob
 from pathlib import Path
+from typing import Optional
 
 import h5py
 import numpy as np
 from elf.io import open_file
+from synapse_net.training.supervised_training import get_3d_model
 from synapse_net.inference.actin import segment_actin
+import torch_em
+import torch
 
+def predict_actin(input_dir, model_path, output_dir, device: int=0, torch_load: bool=False, state_key: Optional[str]=None):
+    input_dir = Path(input_dir)
+    output_dir = Path(output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
 
-# Run prediction on the actin val volume.
-def predict_actin_val():
-    path = "/mnt/lustre-grete/usr/u12086/data/deepict/deepict_actin/00012.h5"
+    model_path = Path(model_path)
+    model_name = model_path.stem
 
-    # This is the validation ROI.
-    roi = np.s_[250:, :, :]
-    with h5py.File(path, "r") as f:
-        raw = f["raw"][roi]
+    if torch_load:
+        ckpt = str(model_path / "best.pt")
+        x = torch.load(ckpt, map_location=f"cuda:{device}", weights_only=False)
+        model = get_3d_model(out_channels=2)
+        if state_key is None:
+            state_key = "model_state"
+        model.load_state_dict(x[state_key])
+    else:
+        model = torch_em.util.load_model(str(model_path), device=f"cuda:{device}")
 
-    model_path = "./checkpoints/actin-deepict"
-    seg, pred = segment_actin(raw, model_path, verbose=True, return_predictions=True)
+    for data_path in input_dir.glob("*.h5"):
+        with h5py.File(data_path, "r") as f:
+            raw = f["raw"][:]
+            labels = f["labels/actin"][:]
 
-    with h5py.File("actin_pred.h5", "a") as f:
-        f.create_dataset("raw", data=raw, compression="gzip")
-        f.create_dataset("actin_seg", data=seg, compression="gzip")
-        f.create_dataset("actin_pred", data=pred, compression="gzip")
+        seg, pred = segment_actin(raw, model=model, verbose=True, return_predictions=True)
 
+        output_path = output_dir / f"{data_path.stem}.h5"
 
-def predict_actin_fb():
-    root = "/mnt/lustre-emmy-hdd/projects/nim00007/data/synaptic-reconstruction/fernandez-busnadiego/from_arsen/tomos_actin_18924" # noqa
-    files = glob(os.path.join(root, "*.mrc"))
+        print(f"Writing prediction to {output_path}.")
+        with h5py.File(output_path, "a") as f:
+            if "raw" not in f:
+                f.create_dataset("raw", data=raw, compression="gzip")
+            if "labels/actin" not in f:
+                f.create_dataset("labels/actin", data=labels, compression="gzip")
+            f.create_dataset(f"predictions/{model_name}", data=pred, compression="gzip")
+            f.create_dataset(f"segmentations/{model_name}", data=seg, compression="gzip")
 
-    model_path = "./checkpoints/actin-adapted-v1"
-
-    for ff in files:
-        print("Predict", ff)
-        with open_file(ff, "r") as f:
-            raw = f["data"][:]
-        seg, pred = segment_actin(raw, model_path, verbose=True, return_predictions=True)
-
-        out_path = f"{Path(ff).stem}.h5"
-        with h5py.File(out_path, "a") as f:
-            # f.create_dataset("raw", data=raw, compression="gzip")
-            f.create_dataset("actin_seg", data=seg, compression="gzip")
-            f.create_dataset("actin_pred", data=pred, compression="gzip")
+def main():
+    MODEL_DIR = Path("/mnt/data1/sage/synapse-net/scripts/cryo/actin/output")
+    PRED_DIR = Path("/mnt/data1/sage/synapse-net/scripts/cryo/actin/predictions")
 
+    predict_actin(
+        input_dir = "/mnt/data1/sage/actin-segmentation/data/deepict/deepict_actin/test",
+        model_path = MODEL_DIR / "experiment2/run3/checkpoints/actin-adapted-opto2deepict-v2",
+        output_dir = PRED_DIR / "deepict",
+        device = 3,
+        torch_load=True,
+        state_key="teacher_state"
+    )
 
-def main():
-    # predict_actin_val()
-    predict_actin_fb()
+    predict_actin(
+        input_dir = "/mnt/data1/sage/actin-segmentation/data/deepict/deepict_actin/test",
+        model_path = MODEL_DIR / "experiment1/run1/checkpoints/actin-deepict-v3",
+        output_dir = PRED_DIR / "deepict",
+        device = 3,
+        torch_load=True,
+        state_key="model_state"
+    )
 
+    predict_actin(
+        input_dir = "/mnt/data1/sage/actin-segmentation/data/EMPIAR-12292/h5/test",
+        model_path = MODEL_DIR / "experiment1/run3/checkpoints/actin-adapted-deepict2opto-v2",
+        output_dir = PRED_DIR / "opto",
+        device = 3,
+        torch_load=True,
+        state_key="teacher_state"
+    )
 
 if __name__ == "__main__":
     main()
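A minimal sketch (not part of the commit) of how the files written by the new predict_actin could be read back. The file name "tomo.h5" is a placeholder, and "actin-deepict-v3" is one of the checkpoint stems used in main(); the keys mirror the datasets created in the write loop above (raw, labels/actin, predictions/<model_name>, segmentations/<model_name>), with model_name taken from the checkpoint directory stem.

# Sketch only: inspect one output file produced by predict_actin.
# "tomo.h5" is a hypothetical file name for illustration.
import h5py

with h5py.File("predictions/deepict/tomo.h5", "r") as f:
    f.visit(print)  # prints group/dataset names, e.g. raw, labels/actin, predictions/<model_name>, segmentations/<model_name>
    seg = f["segmentations/actin-deepict-v3"][:]
    pred = f["predictions/actin-deepict-v3"][:]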

scripts/cryo/actin/surface_dice.py

Lines changed: 211 additions & 0 deletions

@@ -0,0 +1,211 @@
+#!/bin/env python3
+import sys
+import os
+
+# Add membrain-seg to Python path
+MEMBRAIN_SEG_PATH = "/home/sage/membrain-seg/src"
+if MEMBRAIN_SEG_PATH not in sys.path:
+    sys.path.insert(0, MEMBRAIN_SEG_PATH)
+
+import argparse
+import h5py
+import pandas as pd
+from tqdm import tqdm
+import numpy as np
+from scipy.ndimage import label
+from skimage.measure import regionprops
+
+try:
+    from membrain_seg.segmentation.skeletonize import skeletonization
+    from membrain_seg.benchmark.metrics import masked_surface_dice
+except ImportError:
+    raise ImportError("membrain_seg not found in path. Download source code:" \
+                      "https://github.com/teamtomo/membrain-seg/tree/main/src/membrain_seg")
+    exit()
+
+def load_segmentation(file_path, key):
+    with h5py.File(file_path, "r") as f:
+        data = f[key][:]
+    return data
+
+def evaluate_surface_dice(pred, gt, raw, check):
+    gt_skeleton = skeletonization(gt == 1, batch_size=100000)
+    pred_skeleton = skeletonization(pred, batch_size=100000)
+    mask = gt != 2
+
+    if check:
+        import napari
+        v = napari.Viewer()
+        v.add_image(raw)
+        v.add_labels(gt, name="gt")
+        v.add_labels(gt_skeleton.astype(np.uint16), name="gt_skeleton")
+        v.add_labels(pred, name="pred")
+        v.add_labels(pred_skeleton.astype(np.uint16), name="pred_skeleton")
+
+        napari.run()
+
+    surf_dice, confusion_dict = masked_surface_dice(
+        pred_skeleton, gt_skeleton, pred, gt, mask
+    )
+    return surf_dice, confusion_dict
+
+
+def process_file(pred_path, gt_path, seg_key, gt_key, check,
+                 min_bb_shape=(64, 384, 384), min_thinning_size=2500,
+                 global_eval=False):
+    try:
+        pred = load_segmentation(pred_path, seg_key)
+        gt = load_segmentation(gt_path, gt_key)
+        raw = load_segmentation(gt_path, "raw")
+
+        if global_eval:
+            gt_bin = (gt == 1).astype(np.uint8)
+            pred_bin = pred.astype(np.uint8)
+
+            dice, confusion = evaluate_surface_dice(pred_bin, gt_bin, raw, check)
+            return [{
+                "tomo_name": os.path.basename(pred_path),
+                "gt_component_id": -1, # -1 indicates global eval
+                "surface_dice": dice,
+                **confusion
+            }]
+
+        labeled_gt, _ = label(gt == 1)
+        props = regionprops(labeled_gt)
+        results = []
+
+        for prop in props:
+            if prop.area < min_thinning_size:
+                continue
+
+            comp_id = prop.label
+            bbox_start = prop.bbox[:3]
+            bbox_end = prop.bbox[3:]
+            bbox = tuple(slice(start, stop) for start, stop in zip(bbox_start, bbox_end))
+
+            pad_width = [
+                max(min_shape - (sl.stop - sl.start), 0) // 2
+                for sl, min_shape in zip(bbox, min_bb_shape)
+            ]
+
+            expanded_bbox = tuple(
+                slice(
+                    max(sl.start - pw, 0),
+                    min(sl.stop + pw, dim)
+                )
+                for sl, pw, dim in zip(bbox, pad_width, gt.shape)
+            )
+
+            gt_crop = (labeled_gt[expanded_bbox] == comp_id).astype(np.uint8)
+            pred_crop = pred[expanded_bbox].astype(np.uint8)
+            raw_crop = raw[expanded_bbox]
+
+            try:
+                dice, confusion = evaluate_surface_dice(pred_crop, gt_crop, raw_crop, check)
+            except Exception as e:
+                print(f"Error computing Dice for GT component {comp_id} in {pred_path}: {e}")
+                continue
+
+            result = {
+                "tomo_name": os.path.basename(pred_path),
+                "gt_component_id": comp_id,
+                "surface_dice": dice,
+                **confusion
+            }
+            results.append(result)
+
+        return results
+
+    except Exception as e:
+        print(f"Error processing {pred_path}: {e}")
+        return []
+
+
+def collect_results(input_folder, gt_folder, model_name, check=False,
+                    min_bb_shape=(32, 384, 384), min_thinning_size=2500,
+                    global_eval=False):
+    results = []
+    seg_key = f"/segmentations/{model_name}"
+    gt_key = "/labels/actin"
+    input_folder_name = os.path.basename(os.path.normpath(input_folder))
+
+    for fname in tqdm(os.listdir(input_folder), desc="Processing segmentations"):
+        if not fname.endswith(".h5"):
+            continue
+
+        pred_path = os.path.join(input_folder, fname)
+        print(pred_path)
+        gt_path = os.path.join(gt_folder, fname)
+
+        if not os.path.exists(gt_path):
+            print(f"Warning: Ground truth file not found for {fname}")
+            continue
+
+        file_results = process_file(
+            pred_path, gt_path, seg_key, gt_key, check,
+            min_bb_shape=min_bb_shape,
+            min_thinning_size=min_thinning_size,
+            global_eval=global_eval
+        )
+
+        for res in file_results:
+            res["input_folder"] = input_folder_name
+            results.append(res)
+
+    return results
+
+
+def save_results(results, output_file):
+    new_df = pd.DataFrame(results)
+
+    if os.path.exists(output_file):
+        existing_df = pd.read_excel(output_file)
+
+        combined_df = existing_df[
+            ~existing_df.set_index(["tomo_name", "input_folder", "gt_component_id"]).index.isin(
+                new_df.set_index(["tomo_name", "input_folder", "gt_component_id"]).index
+            )
+        ]
+
+        final_df = pd.concat([combined_df, new_df], ignore_index=True)
+    else:
+        final_df = new_df
+
+    final_df.to_excel(output_file, index=False)
+    print(f"Results saved to {output_file}")
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Compute surface dice per GT component or globally for actin segmentations.")
+    parser.add_argument("--input_folder", "-i", required=True, help="Folder with predicted segmentations (.h5)")
+    parser.add_argument("--gt_folder", "-gt", required=True, help="Folder with ground truth segmentations (.h5)")
+    parser.add_argument("--model_name", "-m", required=True, help="Model name string used in prediction key")
+    parser.add_argument("--check", action="store_true", help="Visualize intermediate outputs in Napari")
+    parser.add_argument("--global_eval", action="store_true", help="If set, compute global surface dice instead of per-component")
+
+    args = parser.parse_args()
+
+    min_bb_shape = (32, 464, 464)
+    min_thinning_size = 2500
+
+    suffix = "global" if args.global_eval else "per_gt_component"
+
+    output_file = f"./evaluation_results/{args.model_name}_surface_dice_{suffix}.xlsx"
+    output_dir = os.path.dirname(output_file)
+    os.makedirs(output_dir, exist_ok=True)
+
+    results = collect_results(
+        args.input_folder,
+        args.gt_folder,
+        args.model_name,
+        args.check,
+        min_bb_shape=min_bb_shape,
+        min_thinning_size=min_thinning_size,
+        global_eval=args.global_eval
+    )
+
+    save_results(results, output_file)
+
+
+if __name__ == "__main__":
+    main()
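A minimal usage sketch (not part of the commit) for the evaluation script, assuming it is run from a directory where surface_dice.py is importable and that predictions were written by the prediction script above. The prediction and ground-truth folders are the paths used in that script, and "actin-deepict-v3" is the checkpoint stem, i.e. the model_name under which segmentations are stored.

# Sketch only: per-GT-component surface dice for the deepict predictions written above.
from surface_dice import collect_results, save_results

results = collect_results(
    input_folder="/mnt/data1/sage/synapse-net/scripts/cryo/actin/predictions/deepict",
    gt_folder="/mnt/data1/sage/actin-segmentation/data/deepict/deepict_actin/test",
    model_name="actin-deepict-v3",  # selects the key segmentations/actin-deepict-v3
    min_bb_shape=(32, 464, 464),    # same crop size as main() passes
    min_thinning_size=2500,
)
save_results(results, "./evaluation_results/actin-deepict-v3_surface_dice_per_gt_component.xlsx")

The same evaluation is available from the command line via the flags defined in main(), e.g. -i for the prediction folder, -gt for the ground-truth folder and -m for the model name.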

synapse_net/training/domain_adaptation.py

Lines changed: 10 additions & 8 deletions

@@ -110,7 +110,9 @@ def _train_epoch_unsupervised(self, progress, forward_context, backprop):
         # Sample from both the supervised and unsupervised loader.
         for xu1, xu2 in self.unsupervised_train_loader:
 
-            # Assuming shape (B, C, D, H, W), only keep the first channel for xu2 (student input).
+            # Keep only the first channel for xu2 (student input).
+            if xu2.ndim != 5:
+                raise ValueError(f"Expect xu2 to have 5 dimensions (B, C, D, H, W), got shape {xu2.shape}.")
             if xu2.shape[1] > 1:
                 xu2 = xu2[:, :1].contiguous()
 
@@ -123,6 +125,8 @@ def _train_epoch_unsupervised(self, progress, forward_context, backprop):
             pseudo_labels, label_filter = self.pseudo_labeler(self.teacher, teacher_input)
 
             # Drop the second channel for xu1 (teacher input) after computing the pseudo labels.
+            if xu1.ndim != 5:
+                raise ValueError(f"Expect xu1 to have 5 dimensions (B, C, D, H, W), got shape {xu1.shape}.")
             if xu1.shape[1] > 1:
                 xu1 = xu1[:, :1].contiguous()
 
@@ -184,7 +188,7 @@ def mean_teacher_adaptation(
     train_background_mask_paths: Optional[Tuple[str]] = None,
     patch_sampler: Optional[callable] = None,
     pseudo_label_sampler: Optional[callable] = None,
-    device: int = 0,
+    device: Optional[torch.device] = None,
 ) -> None:
     """Run domain adapation to transfer a network trained on a source domain for a supervised
     segmentation task to perform this task on a different target domain.
@@ -197,11 +201,9 @@ def mean_teacher_adaptation(
 
     Args:
         name: The name for the checkpoint to be trained.
-        unsupervsied_train_paths: Filepaths to the hdf5 files or similar file formats
-            for the training data in the target domain.
+        unsupervsied_train_paths: Filepaths to the hdf5 or mrc files for the training data in the target domain.
             This training data is used for unsupervised learning, so it does not require labels.
-        unsupervised_val_paths: Filepaths to the hdf5 files or similar file formats
-            for the validation data in the target domain.
+        unsupervised_val_paths: Filepaths to the hdf5 or mrc files for the validation data in the target domain.
             This validation data is used for unsupervised learning, so it does not require labels.
         patch_shape: The patch shape used for a training example.
             In order to run 2d training pass a patch shape with a singleton in the z-axis,
@@ -231,9 +233,9 @@ def mean_teacher_adaptation(
             based on the patch_shape and size of the volumes used for validation.
         train_sample_mask_paths: Filepaths to the sample masks used by the patch sampler to accept or reject
            patches for training.
-        val_sample_mask_paths: Filepaths to the sample masks used by the patch sampler to accept or reject
+        val_sample_mask_paths: Filepaths to the sample masks mrc files used by the patch sampler to accept or reject
            patches for validation.
-        train_background_mask_paths: Filepaths to the background masks used for training.
+        train_background_mask_paths: Filepaths to the background masks mrc files used for training.
            Background masks are used to subtract background from the pseudo labels before the forward pass.
        patch_sampler: A sampler for rejecting patches based on a defined conditon.
        pseudo_label_sampler: A sampler for rejecting pseudo-labels based on a defined condition.
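A minimal sketch (not part of the commit) of the behaviour the two new checks guard in _train_epoch_unsupervised: unsupervised batches are expected as 5D (B, C, D, H, W) tensors, and any channel beyond the first (presumably an extra mask channel supplied by the loader) is dropped before the forward pass. The tensor shape below is an arbitrary example.

# Sketch only: shape handling mirrored from the diff above.
import torch

xu2 = torch.rand(2, 2, 32, 64, 64)  # (B, C, D, H, W) with an extra channel
if xu2.ndim != 5:
    raise ValueError(f"Expect xu2 to have 5 dimensions (B, C, D, H, W), got shape {xu2.shape}.")
if xu2.shape[1] > 1:
    xu2 = xu2[:, :1].contiguous()  # keep only the first channel -> shape (2, 1, 32, 64, 64)
print(tuple(xu2.shape))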
