Skip to content

Commit 466c733

Browse files
committed
added cristae model and single channel transfrom
1 parent 1f0387a commit 466c733

File tree

3 files changed

+10
-2
lines changed

3 files changed

+10
-2
lines changed

synapse_net/inference/inference.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .mitochondria import segment_mitochondria
1111
from .ribbon_synapse import segment_ribbon_synapse_structures
1212
from .vesicles import segment_vesicles
13+
from .cristae import segment_cristae
1314
from .util import get_device
1415
from ..file_utils import get_cache_dir
1516

@@ -25,6 +26,7 @@ def _get_model_registry():
2526
"compartments": "527983720f9eb215c45c4f4493851fd6551810361eda7b79f185a0d304274ee1",
2627
"mitochondria": "24625018a5968b36f39fa9d73b121a32e8f66d0f2c0540d3df2e1e39b3d58186",
2728
"mitochondria2": "553decafaff4838fff6cc8347f22c8db3dee5bcbeffc34ffaec152f8449af673",
29+
"cristae": "f96c90484f4ea92ac0515a06e389cc117580f02c2aacdc44b5828820cf38c3c3",
2830
"ribbon": "7c947f0ddfabe51a41d9d05c0a6ca7d6b238f43df2af8fffed5552d09bb075a9",
2931
"vesicles_2d": "eb0b74f7000a0e6a25b626078e76a9452019f2d1ea6cf2033073656f4f055df1",
3032
"vesicles_3d": "b329ec1f57f305099c984fbb3d7f6ae4b0ff51ec2fa0fa586df52dad6b84cf29",
@@ -35,6 +37,7 @@ def _get_model_registry():
3537
"compartments": "https://owncloud.gwdg.de/index.php/s/DnFDeTmDDmZrDDX/download",
3638
"mitochondria": "https://owncloud.gwdg.de/index.php/s/1T542uvzfuruahD/download",
3739
"mitochondria2": "https://owncloud.gwdg.de/index.php/s/GZghrXagc54FFXd/download",
40+
"cristae": "https://owncloud.gwdg.de/index.php/s/Df7OUOyQ1Kc2eEO/download",
3841
"ribbon": "https://owncloud.gwdg.de/index.php/s/S3b5l0liPP1XPYA/download",
3942
"vesicles_2d": "https://owncloud.gwdg.de/index.php/s/d72QIvdX6LsgXip/download",
4043
"vesicles_3d": "https://owncloud.gwdg.de/index.php/s/A425mkAOSqePDhx/download",
@@ -214,14 +217,16 @@ def run_segmentation(
214217
"""
215218
if model_type.startswith("vesicles"):
216219
segmentation = segment_vesicles(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
217-
elif model_type == "mitochondria":
220+
elif model_type == "mitochondria" or model_type == "mitochondria2":
218221
segmentation = segment_mitochondria(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
219222
elif model_type == "active_zone":
220223
segmentation = segment_active_zone(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
221224
elif model_type == "compartments":
222225
segmentation = segment_compartments(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
223226
elif model_type == "ribbon":
224227
segmentation = _segment_ribbon_AZ(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
228+
elif model_type == "cristae":
229+
segmentation = segment_cristae(image, model=model, tiling=tiling, scale=scale, verbose=verbose, **kwargs)
225230
else:
226231
raise ValueError(f"Unknown model type: {model_type}")
227232
return segmentation

synapse_net/inference/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def get_prediction(
121121
# If we have channels then the standardization is done independently per channel.
122122
if with_channels:
123123
# TODO Check that this is the correct axis.
124-
input_volume = torch_em.transform.raw.standardize(input_volume, axis=(1, 2, 3))
124+
input_volume = np.stack([torch_em.transform.raw.normalize(input_volume[0]), input_volume[1]], axis=0)
125125
else:
126126
input_volume = torch_em.transform.raw.standardize(input_volume)
127127

synapse_net/tools/segmentation_widget.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,9 @@ def on_predict(self):
178178
if model_type == "ribbon": # Currently only the ribbon model needs the extra seg.
179179
extra_seg = self._get_layer_selector_data(self.extra_seg_selector_name)
180180
kwargs = {"extra_segmentation": extra_seg}
181+
elif model_type == "cristae": # Cristae model expects 2 3D volumes
182+
image = np.stack([image, self._get_layer_selector_data(self.extra_seg_selector_name)], axis=0)
183+
kwargs = {}
181184
else:
182185
kwargs = {}
183186
segmentation = run_segmentation(

0 commit comments

Comments
 (0)