Skip to content

Commit b842bca

Browse files
committed
new data analysis, not tested
1 parent 6525651 commit b842bca

File tree

4 files changed

+538
-0
lines changed

4 files changed

+538
-0
lines changed
Lines changed: 259 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,259 @@
1+
import os
2+
import numpy as np
3+
import h5py
4+
5+
from skimage.measure import regionprops
6+
from skimage.morphology import remove_small_holes
7+
from skimage.segmentation import relabel_sequential
8+
9+
from synapse_net.inference.vesicles import segment_vesicles
10+
from synapse_net.inference.compartments import segment_compartments
11+
from synapse_net.inference.active_zone import segment_active_zone
12+
from synapse_net.inference.inference import get_model_path
13+
14+
15+
def fill_and_filter_vesicles(
    vesicles: np.ndarray, min_size: int = 2500, hole_area_threshold: int = 1000
) -> np.ndarray:
    """
    Apply a size filter and fill small holes in vesicle segments.

    Segments with fewer than `min_size` voxels are set to 0. Afterwards, for each
    remaining segment, holes found by `remove_small_holes` inside the segment's
    bounding box are filled with the segment's label.

    Args:
        vesicles (np.ndarray): Volume with vesicle segment labels; 0 is treated as background.
        min_size (int): Minimum voxel count a segment needs in order to be kept.
        hole_area_threshold (int): Forwarded to `remove_small_holes` as `area_threshold`.

    Returns:
        np.ndarray: Processed copy of the segmentation; the input array is not modified.
    """
    ids, sizes = np.unique(vesicles, return_counts=True)
    # Remove the background by value. Slicing off the first entry would silently
    # drop a real segment whenever the volume contains no 0 voxel.
    foreground = ids != 0
    ids, sizes = ids[foreground], sizes[foreground]

    vesicles_pp = vesicles.copy()
    filter_ids = ids[sizes < min_size]
    vesicles_pp[np.isin(vesicles, filter_ids)] = 0

    ndim = vesicles_pp.ndim
    for prop in regionprops(vesicles_pp):
        # bbox is laid out as (min_0, ..., min_k, max_0, ..., max_k); pairing the two
        # halves builds the crop without hard-coding three dimensions.
        bbox = prop.bbox
        bb = tuple(slice(lo, hi) for lo, hi in zip(bbox[:ndim], bbox[ndim:]))
        mask = vesicles_pp[bb] == prop.label
        mask = remove_small_holes(mask, area_threshold=hole_area_threshold)
        # Basic slicing yields a view, so this assignment writes into vesicles_pp.
        vesicles_pp[bb][mask] = prop.label

    return vesicles_pp
44+
45+
46+
def SV_pred(raw: np.ndarray, SV_model: str, output_path: str = None, store: bool = False) -> np.ndarray:
    """
    Segment synaptic vesicles and, on request, write the results to an HDF5 file.

    When storing, the prediction goes to "predictions/SV/pred" and the segmentation
    to "predictions/SV/seg". A dataset that already exists in the file is left
    untouched and only reported on stdout.

    Args:
        raw (np.ndarray): Raw EM image volume.
        SV_model (str): Path to vesicle model.
        output_path (str): HDF5 file to store predictions.
        store (bool): Whether to store predictions.

    Returns:
        np.ndarray: Segmentation result.
    """
    seg, pred = segment_vesicles(input_volume=raw, model_path=SV_model, verbose=False, return_predictions=True)

    # Guard clauses: nothing to write unless storing was requested AND a path exists.
    if not store:
        print("Not storing SV predictions")
        return seg
    if not output_path:
        print("Output path is missing, not storing SV predictions")
        return seg

    outputs = (("predictions/SV/pred", pred), ("predictions/SV/seg", seg))
    # Append mode keeps whatever other datasets the file already holds.
    with h5py.File(output_path, "a") as f:
        for key, data in outputs:
            if key in f:
                print(f"{key} already saved")
            else:
                f.create_dataset(key, data=data, compression="lzf")

    return seg
80+
81+
82+
def compartment_pred(raw: np.ndarray, compartment_model: str, output_path: str = None, store: bool = False) -> np.ndarray:
    """
    Segment compartments and, on request, write the results to an HDF5 file.

    When storing, the prediction goes to "predictions/compartment/pred" and the
    segmentation to "predictions/compartment/seg". A dataset that already exists
    in the file is left untouched and only reported on stdout.

    Args:
        raw (np.ndarray): Raw EM image volume.
        compartment_model (str): Path to compartment model.
        output_path (str): HDF5 file to store predictions.
        store (bool): Whether to store predictions.

    Returns:
        np.ndarray: Segmentation result.
    """
    seg, pred = segment_compartments(input_volume=raw, model_path=compartment_model, verbose=False, return_predictions=True)

    # Guard clauses: nothing to write unless storing was requested AND a path exists.
    if not store:
        print("Not storing compartment predictions")
        return seg
    if not output_path:
        print("Output path is missing, not storing compartment predictions")
        return seg

    outputs = (("predictions/compartment/pred", pred), ("predictions/compartment/seg", seg))
    # Append mode keeps whatever other datasets the file already holds.
    with h5py.File(output_path, "a") as f:
        for key, data in outputs:
            if key in f:
                print(f"{key} already saved")
            else:
                f.create_dataset(key, data=data, compression="lzf")

    return seg
116+
117+
118+
def AZ_pred(raw: np.ndarray, AZ_model: str, output_path: str = None, store: bool = False) -> np.ndarray:
    """
    Segment the active zone and, on request, write the results to an HDF5 file.

    When storing, the prediction goes to "predictions/az/pred" and the segmentation
    to "predictions/az/seg". A dataset that already exists in the file is left
    untouched and only reported on stdout.

    Args:
        raw (np.ndarray): Raw EM image volume.
        AZ_model (str): Path to AZ model.
        output_path (str): HDF5 file to store predictions.
        store (bool): Whether to store predictions.

    Returns:
        np.ndarray: Segmentation result.
    """
    seg, pred = segment_active_zone(raw, model_path=AZ_model, verbose=False, return_predictions=True)

    # Guard clauses: nothing to write unless storing was requested AND a path exists.
    if not store:
        print("Not storing AZ predictions")
        return seg
    if not output_path:
        print("Output path is missing, not storing AZ predictions")
        return seg

    outputs = (("predictions/az/pred", pred), ("predictions/az/seg", seg))
    # Append mode keeps whatever other datasets the file already holds.
    with h5py.File(output_path, "a") as f:
        for key, data in outputs:
            if key in f:
                print(f"{key} already saved")
            else:
                f.create_dataset(key, data=data, compression="lzf")

    return seg
152+
153+
154+
def filter_presynaptic_SV(sv_seg: np.ndarray, compartment_seg: np.ndarray, output_path: str = None,
                          store: bool = False, input_path: str = None) -> np.ndarray:
    """
    Filter the synaptic vesicle segmentation to retain only vesicles in the presynaptic region.

    The presynaptic compartment is taken to be the compartment label that contains
    the most distinct vesicle labels. Vesicles that have no voxel inside that
    compartment are removed, unless the input file name is on a hard-coded list of
    files for which the compartment mask was judged wrong by visual inspection.

    Args:
        sv_seg (np.ndarray): Vesicle segmentation.
        compartment_seg (np.ndarray): Compartment segmentation.
        output_path (str): Optional HDF5 file to store outputs.
        store (bool): Whether to store outputs.
        input_path (str): Path to input file; only its basename is used, to look it
            up in the list of files exempt from mask filtering.

    Returns:
        np.ndarray: Filtered presynaptic vesicle segmentation, relabeled sequentially.
    """
    # Apply the size filter and fill small holes in the vesicles.
    vesicles_pp = fill_and_filter_vesicles(sv_seg)

    def n_vesicles(mask, ves):
        # Count distinct non-zero vesicle labels inside the region. Subtracting one
        # from the unique count would undercount a region without background voxels.
        return int(np.count_nonzero(np.unique(ves[mask])))

    # Find the compartment with most vesicles. The property is exposed on each
    # region under the helper's function name, hence `prop.n_vesicles`.
    props = regionprops(compartment_seg, intensity_image=vesicles_pp, extra_properties=[n_vesicles])
    compartment_ids = [prop.label for prop in props]
    vesicle_counts = [prop.n_vesicles for prop in props]
    if compartment_ids:
        presynapse_id = compartment_ids[int(np.argmax(vesicle_counts))]
        mask = (compartment_seg == presynapse_id).astype("uint8")
    else:
        # No compartment at all: keep everything. Same dtype as the branch above so
        # the stored dataset does not change type depending on the input.
        mask = np.ones(compartment_seg.shape, dtype="uint8")

    # A vesicle whose maximum mask value is 0 has no voxel inside the mask.
    props = regionprops(vesicles_pp, mask)
    filter_ids = [prop.label for prop in props if prop.max_intensity == 0]

    name = os.path.basename(input_path) if input_path else "unknown"
    print(name)

    # Don't filter for wrong masks (visual inspection).
    no_filter = (
        "C_M13DKO_080212_CTRL6.7B_crop.h5", "E_M13DKO_080212_DKO1.2_crop.h5",
        "G_M13DKO_080212_CTRL6.7B_crop.h5", "A_SNAP25_120812_CTRL2.3_14_crop.h5",
        "A_SNAP25_12082_KO2.1_6_crop.h5", "B_SNAP25_120812_CTRL2.3_14_crop.h5",
        "B_SNAP25_12082_CTRL2.3_5_crop.h5", "D_SNAP25_120812_CTRL2.3_14_crop.h5",
        "G_SNAP25_12.08.12_KO1.1_3_crop.h5",
    )
    if name not in no_filter:
        vesicles_pp[np.isin(vesicles_pp, filter_ids)] = 0

    # NOTE(review): the vesicles are stored BEFORE the sequential relabeling below, so
    # ids in the file differ from the ids returned to the caller — confirm this is intended.
    if store and output_path:
        outputs = (
            ("predictions/compartment/presynapse", mask),
            ("predictions/SV/presynaptic", vesicles_pp),
        )
        with h5py.File(output_path, "a") as f:
            for key, data in outputs:
                if key in f:
                    print(f"{key} already saved")
                else:
                    f.create_dataset(key, data=data, compression="lzf")
    elif store:
        print("Output path is missing, not storing presynapse seg and presynaptic SV seg")
    else:
        print("Not storing presynapse seg and presynaptic SV seg")

    # All non-zero labels are relabeled to the sequence 1, 2, 3, ..., n.
    # This makes the analysis easier: distances and diameters can be matched by id.
    vesicles_pp, _, _ = relabel_sequential(vesicles_pp)

    return vesicles_pp
223+
224+
225+
def run_predictions(input_path: str, output_path: str = None, store: bool = False, az_model_path: str = None):
    """
    Run full inference pipeline: vesicles, compartments, active zone, and presynaptic SV filtering.

    Args:
        input_path (str): Path to input HDF5 file with 'raw' dataset.
        output_path (str): Path to output HDF5 file to store predictions.
        store (bool): Whether to store intermediate and final results.
        az_model_path (str): Path to the AZ model checkpoint. When omitted, the
            previously hard-coded cluster path is used, so existing callers are unaffected.

    Returns:
        Tuple[np.ndarray, np.ndarray]: (Filtered vesicle segmentation, AZ segmentation)
    """
    # Read the whole raw volume into memory; the file is closed before inference starts.
    with h5py.File(input_path, "r") as f:
        raw = f["raw"][:]

    SV_model = get_model_path("vesicles_3d")
    compartment_model = get_model_path("compartments")
    # TODO upload better AZ model
    # The fallback is an absolute path on one specific machine; pass `az_model_path`
    # to run the pipeline anywhere else.
    AZ_model = az_model_path or "/mnt/lustre-emmy-hdd/usr/u12095/synapse_net/models/ConstantinAZ/checkpoints/v7/"

    print("Running SV prediction")
    sv_seg = SV_pred(raw, SV_model, output_path, store)

    print("Running compartment prediction")
    comp_seg = compartment_pred(raw, compartment_model, output_path, store)

    print("Running AZ prediction")
    az_seg = AZ_pred(raw, AZ_model, output_path, store)

    print("Filtering the presynaptic SV")
    presyn_SV_seg = filter_presynaptic_SV(sv_seg, comp_seg, output_path, store, input_path)

    print("Done with predictions")

    return presyn_SV_seg, az_seg
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
from synapse_net.distance_measurements import measure_segmentation_to_object_distances
2+
from synapse_net.imod.to_imod import convert_segmentation_to_spheres
3+
4+
5+
def calc_AZ_SV_distance(vesicles, az, resolution):
    """
    Measure how far each synaptic vesicle (SV) is from the active zone (AZ).

    Args:
        vesicles (np.ndarray): Segmentation of synaptic vesicles.
        az (np.ndarray): Segmentation of the active zone.
        resolution (tuple): Voxel resolution in nanometers (z, y, x).

    Returns:
        list of dict: One {'seg_id', 'distance'} dict per vesicle, ordered by seg_id.
    """
    # Only the first and the last of the four returned values are needed here.
    dists, _, _, ids = measure_segmentation_to_object_distances(vesicles, az, resolution=resolution)

    records = ({"seg_id": vesicle_id, "distance": value} for vesicle_id, value in zip(ids, dists))
    return sorted(records, key=lambda record: record["seg_id"])
23+
24+
25+
def sort_by_distances(input_list):
    """
    Return the entries ordered by their 'distance' value, smallest first.

    The input is left untouched; entries with equal distance keep their
    original relative order.

    Args:
        input_list (list of dict): Entries that each carry a 'distance' key.

    Returns:
        list of dict: New list sorted by ascending distance.
    """
    ordered = list(input_list)
    ordered.sort(key=lambda entry: entry["distance"])
    return ordered
37+
38+
39+
def combine_lists(list1, list2):
    """
    Merge two lists of dictionaries on their shared 'seg_id' key.

    Entries are copied, so neither input list is modified. For a seg_id present
    in both lists, the values from `list2` overwrite those from `list1` on
    overlapping keys. The result keeps first-seen seg_id order.

    Args:
        list1 (list of dict): First list with 'seg_id' key.
        list2 (list of dict): Second list with 'seg_id' key.

    Returns:
        list of dict: One merged dictionary per distinct seg_id.
    """
    merged = {entry["seg_id"]: entry.copy() for entry in list1}

    for entry in list2:
        sid = entry["seg_id"]
        existing = merged.get(sid)
        if existing is None:
            # seg_id only occurs in list2: take the entry as is.
            merged[sid] = entry.copy()
        else:
            existing.update((k, v) for k, v in entry.items() if k != "seg_id")

    return list(merged.values())
67+
68+
69+
def calc_SV_diameters(vesicles, resolution):
    """
    Calculate diameters of synaptic vesicles from segmentation data.

    Args:
        vesicles (np.ndarray): Segmentation of synaptic vesicles.
        resolution (tuple): Voxel resolution in nanometers (z, y, x).

    Returns:
        list of dict: Each dict contains 'seg_id' and 'diameter', sorted by seg_id.

    Raises:
        ValueError: If the number of radii does not match the number of non-zero
            labels in `vesicles`, so ids cannot be assigned reliably.
    """
    # Imported here so the function works regardless of the module's file-level imports.
    import numpy as np

    _, radii = convert_segmentation_to_spheres(
        vesicles, resolution=resolution, radius_factor=0.7, estimate_radius_2d=True
    )

    # Use the real segment labels as ids. A 0-based running index would be off by one
    # against label ids (0 is background), breaking any later match on 'seg_id'.
    # Presumably radii come back in ascending label order — confirm against
    # convert_segmentation_to_spheres.
    labels = np.unique(vesicles)
    seg_ids = labels[labels != 0].tolist()
    if len(seg_ids) != len(radii):
        raise ValueError(
            f"Got {len(radii)} radii for {len(seg_ids)} vesicle labels; cannot assign seg_ids"
        )

    # NOTE(review): `resolution` is also passed to convert_segmentation_to_spheres above.
    # If that call already returns radii in physical units, this scales twice — confirm.
    radii_nm = radii * resolution[0]
    diameters = radii_nm * 2

    diam_list = [{"seg_id": sid, "diameter": diam} for sid, diam in zip(seg_ids, diameters)]
    diam_list.sort(key=lambda x: x["seg_id"])

    return diam_list

0 commit comments

Comments
 (0)