1010from  synapse_net .inference .compartments  import  segment_compartments 
1111from  synapse_net .inference .active_zone  import  segment_active_zone 
1212from  synapse_net .inference .inference  import  get_model_path 
13+ from  synapse_net .ground_truth .az_evaluation  import  _get_presynaptic_mask 
1314
1415
1516def  fill_and_filter_vesicles (vesicles : np .ndarray ) ->  np .ndarray :
@@ -130,7 +131,7 @@ def compartment_pred(raw: np.ndarray, compartment_model: str, output_path: str =
130131        else :
131132            print ("Not storing compartment predictions" )
132133
133-     return  seg 
134+     return  seg ,  pred 
134135
135136
136137def  AZ_pred (raw : np .ndarray , AZ_model : str , output_path : str  =  None , store : bool  =  False ) ->  np .ndarray :
@@ -179,7 +180,7 @@ def AZ_pred(raw: np.ndarray, AZ_model: str, output_path: str = None, store: bool
179180    return  seg 
180181
181182
182- def  filter_presynaptic_SV (sv_seg : np .ndarray , compartment_seg : np .ndarray , output_path : str  =  None ,
183+ def  filter_presynaptic_SV (sv_seg : np .ndarray , compartment_seg : np .ndarray , compartment_pred :  np . ndarray ,  output_path : str  =  None ,
183184                          store : bool  =  False , input_path : str  =  None ) ->  np .ndarray :
184185    """ 
185186    Filters synaptic vesicle segmentation to retain only vesicles in the presynaptic region. 
@@ -200,14 +201,16 @@ def filter_presynaptic_SV(sv_seg: np.ndarray, compartment_seg: np.ndarray, outpu
200201    def  n_vesicles (mask , ves ):
201202        return  len (np .unique (ves [mask ])) -  1 
202203
203-     # Find the segment with most vesicles. 
204+     ''' # Find the segment with most vesicles.
204205    props = regionprops(compartment_seg, intensity_image=vesicles_pp, extra_properties=[n_vesicles]) 
205206    compartment_ids = [prop.label for prop in props] 
206207    vesicle_counts = [prop.n_vesicles for prop in props] 
207208    if len(compartment_ids) == 0: 
208209        mask = np.ones(compartment_seg.shape, dtype="bool") 
209210    else: 
210-         mask  =  (compartment_seg  ==  compartment_ids [np .argmax (vesicle_counts )]).astype ("uint8" )
211+         mask = (compartment_seg == compartment_ids[np.argmax(vesicle_counts)]).astype("uint8")''' 
212+ 
213+     mask  =  _get_presynaptic_mask (compartment_pred , vesicles_pp )
211214
212215    # Filter all vesicles that are not in the mask. 
213216    props  =  regionprops (vesicles_pp , mask )
@@ -274,13 +277,13 @@ def run_predictions(input_path: str, output_path: str = None, store: bool = Fals
274277    sv_seg  =  SV_pred (raw , SV_model , output_path , store )
275278
276279    print ("Running compartment prediction" )
277-     comp_seg  =  compartment_pred (raw , compartment_model , output_path , store )
280+     comp_seg ,  comp_pred  =  compartment_pred (raw , compartment_model , output_path , store )
278281
279282    print ("Running AZ prediction" )
280283    az_seg  =  AZ_pred (raw , AZ_model , output_path , store )
281284
282285    print ("Filtering the presynaptic SV" )
283-     presyn_SV_seg  =  filter_presynaptic_SV (sv_seg , comp_seg , output_path , store , input_path )
286+     presyn_SV_seg  =  filter_presynaptic_SV (sv_seg , comp_seg , comp_pred ,  output_path , store , input_path )
284287
285288    print ("Done with predictions" )
286289
0 commit comments