99from ..inference .active_zone import segment_active_zone
1010from ..inference .compartments import segment_compartments
1111from ..inference .mitochondria import segment_mitochondria
12+ from ..inference .ribbon_synapse import segment_ribbon_synapse_structures
1213from ..inference .vesicles import segment_vesicles
1314
1415
@@ -43,8 +44,8 @@ def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None
4344 """Get the model for the given segmentation type.
4445
4546 Args:
46- model_type: The model type.
47- One of 'vesicles ', 'mitochondria', 'active_zone ', 'compartments' or 'inner_ear_structures '.
47+ model_type: The model type. You can choose One of:
48+ 'vesicles_3d', 'active_zone', 'compartments ', 'mitochondria', 'ribbon ', 'vesicles_2d', 'vesicles_cryo '.
4849 device: The device to use.
4950
5051 Returns:
@@ -58,6 +59,44 @@ def get_model(model_type: str, device: Optional[Union[str, torch.device]] = None
5859 return model
5960
6061
62+ def _segment_ribbon_AZ (image , model , tiling , scale , verbose , ** kwargs ):
63+ # Parse additional keyword arguments from the kwargs.
64+ vesicles = kwargs .pop ("extra_segmentation" )
65+ threshold = kwargs .pop ("threshold" , 0.5 )
66+ n_slices_exclude = kwargs .pop ("n_slices_exclude" , 20 )
67+ n_ribbons = kwargs .pop ("n_slices_exclude" , 1 )
68+
69+ predictions = segment_ribbon_synapse_structures (
70+ image , model = model , tiling = tiling , scale = scale , verbose = verbose , threshold = threshold , ** kwargs
71+ )
72+
73+ # If the vesicles were passed then run additional post-processing.
74+ if vesicles is None :
75+ from synaptic_reconstruction .inference .postprocessing import (
76+ segment_ribbon , segment_presynaptic_density , segment_membrane_distance_based ,
77+ )
78+
79+ ribbon = segment_ribbon (
80+ predictions ["ribbon" ], vesicles , n_slices_exclude = n_slices_exclude , n_ribbons = n_ribbons ,
81+ max_vesicle_distance = 40 ,
82+ )
83+ PD = segment_presynaptic_density (
84+ predictions ["PD" ], ribbon , n_slices_exclude = n_slices_exclude , max_distance_to_ribbon = 40 ,
85+ )
86+ ref_segmentation = PD if PD .sum () > 0 else ribbon
87+ membrane = segment_membrane_distance_based (
88+ predictions ["membrane" ], ref_segmentation , n_sclices_exclude = n_slices_exclude , max_distance = 500
89+ )
90+
91+ segmentation = {"ribbon" : ribbon , "PD" : PD , "membrane" : membrane }
92+
93+ # Otherwise, just return the predictions.
94+ else :
95+ segmentation = predictions
96+
97+ return segmentation
98+
99+
61100def run_segmentation (
62101 image : np .ndarray ,
63102 model : torch .nn .Module ,
@@ -66,22 +105,21 @@ def run_segmentation(
66105 scale : Optional [List [float ]] = None ,
67106 verbose : bool = False ,
68107 ** kwargs ,
69- ) -> np .ndarray :
108+ ) -> np .ndarray | Dict [ str , np . ndarray ] :
70109 """Run synaptic structure segmentation.
71110
72111 Args:
73112 image: The input image or image volume.
74113 model: The segmentation model.
75- model_type: The model type. This will determine which segmentation
76- post-processing is used.
114+ model_type: The model type. This will determine which segmentation post-processing is used.
77115 tiling: The tiling settings for inference.
78116 scale: A scale factor for resizing the input before applying the model.
79117 The output will be scaled back to the initial size.
80118 verbose: Whether to print detailed information about the prediction and segmentation.
81- kwargs: Optional parameter for the segmentation function.
119+ kwargs: Optional parameters for the segmentation function.
82120
83121 Returns:
84- The segmentation.
122+ The segmentation. For models that return multiple segmentations, this function returns a dictionary.
85123 """
86124 if model_type .startswith ("vesicles" ):
87125 segmentation = segment_vesicles (image , model = model , tiling = tiling , scale = scale , verbose = verbose , ** kwargs )
@@ -91,8 +129,8 @@ def run_segmentation(
91129 segmentation = segment_active_zone (image , model = model , tiling = tiling , scale = scale , verbose = verbose , ** kwargs )
92130 elif model_type == "compartments" :
93131 segmentation = segment_compartments (image , model = model , tiling = tiling , scale = scale , verbose = verbose , ** kwargs )
94- elif model_type == "ribbon_synapse_structures " :
95- raise NotImplementedError
132+ elif model_type == "ribbon " :
133+ segmentation = _segment_ribbon_AZ ( image , model = model , tiling = tiling , scale = scale , verbose = verbose , ** kwargs )
96134 else :
97135 raise ValueError (f"Unknown model type: { model_type } " )
98136 return segmentation
@@ -108,6 +146,7 @@ def get_model_training_resolution(model_type):
108146 "active_zone" : {"x" : 1.44 , "y" : 1.44 , "z" : 1.44 },
109147 "compartments" : {"x" : 3.47 , "y" : 3.47 , "z" : 3.47 },
110148 "mitochondria" : {"x" : 2.07 , "y" : 2.07 , "z" : 2.07 },
149+ "ribbon" : {"x" : 1.188 , "y" : 1.188 , "z" : 1.188 },
111150 "vesicles_2d" : {"x" : 1.35 , "y" : 1.35 },
112151 "vesicles_3d" : {"x" : 1.35 , "y" : 1.35 , "z" : 1.35 },
113152 "vesicles_cryo" : {"x" : 1.35 , "y" : 1.35 , "z" : 0.88 },
@@ -120,6 +159,7 @@ def get_model_registry():
120159 "active_zone" : "a18f29168aed72edec0f5c2cb1aa9a4baa227812db6082a6538fd38d9f43afb0" ,
121160 "compartments" : "527983720f9eb215c45c4f4493851fd6551810361eda7b79f185a0d304274ee1" ,
122161 "mitochondria" : "24625018a5968b36f39fa9d73b121a32e8f66d0f2c0540d3df2e1e39b3d58186" ,
162+ "ribbon" : "7c947f0ddfabe51a41d9d05c0a6ca7d6b238f43df2af8fffed5552d09bb075a9" ,
123163 "vesicles_2d" : "eb0b74f7000a0e6a25b626078e76a9452019f2d1ea6cf2033073656f4f055df1" ,
124164 "vesicles_3d" : "b329ec1f57f305099c984fbb3d7f6ae4b0ff51ec2fa0fa586df52dad6b84cf29" ,
125165 "vesicles_cryo" : "782f5a21c3cda82c4e4eaeccc754774d5aaed5929f8496eb018aad7daf91661b" ,
@@ -128,6 +168,7 @@ def get_model_registry():
128168 "active_zone" : "https://owncloud.gwdg.de/index.php/s/zvuY342CyQebPsX/download" ,
129169 "compartments" : "https://owncloud.gwdg.de/index.php/s/DnFDeTmDDmZrDDX/download" ,
130170 "mitochondria" : "https://owncloud.gwdg.de/index.php/s/1T542uvzfuruahD/download" ,
171+ "ribbon" : "https://owncloud.gwdg.de/index.php/s/S3b5l0liPP1XPYA/download" ,
131172 "vesicles_2d" : "https://owncloud.gwdg.de/index.php/s/d72QIvdX6LsgXip/download" ,
132173 "vesicles_3d" : "https://owncloud.gwdg.de/index.php/s/A425mkAOSqePDhx/download" ,
133174 "vesicles_cryo" : "https://owncloud.gwdg.de/index.php/s/e2lVdxjCJuZkLJm/download" ,
0 commit comments