@@ -104,6 +104,40 @@ def process_labels_hacky(coords, shape, sigma, eps, bb=None):
104104 return labels
105105
106106
107+ def process_labels_stamped (coords , shape , sigma , eps , bb ):
108+ offset = 7
109+
110+ if bb :
111+ (z_min , z_max ), (y_min , y_max ), (x_min , x_max ) = [(s .start , s .stop ) for s in bb ]
112+ restricted_shape = (z_max - z_min , y_max - y_min , x_max - x_min )
113+ full_shape = tuple (sh + 2 * offset for sh in restricted_shape )
114+ labels = np .zeros (full_shape , dtype = "float32" )
115+ shape = restricted_shape
116+ else :
117+ full_shape = tuple (sh + 2 * offset for sh in shape )
118+
119+ labels = np .zeros (full_shape , dtype = "float32" )
120+
121+ z , y , x = coords
122+ if len (z ) == 0 :
123+ return np .zeros (shape , dtype = "float32" )
124+ coordinates = np .concatenate ([z [:, None ], y [:, None ], x [:, None ]], axis = 1 )
125+
126+ stamp = np .zeros ((2 * offset , 2 * offset , 2 * offset ), dtype = "float32" )
127+ stamp [offset - 1 , offset - 1 , offset - 1 ] = 1
128+ stamp = gaussian (stamp , sigma = sigma )
129+ stamp /= stamp .max ()
130+
131+ for coord in coordinates :
132+ bb = tuple (slice (co + offset - offset , co + offset + offset ) for co in coord )
133+ val = np .maximum (labels [bb ], stamp )
134+ labels [bb ] = val
135+
136+ labels = labels [offset :- offset , offset :- offset , offset :- offset ]
137+ assert labels .shape == shape
138+ return labels
139+
140+
107141class DetectionDataset (torch .utils .data .Dataset ):
108142 max_sampling_attempts = 500
109143
@@ -209,7 +243,8 @@ def _get_sample(self, index):
209243 # label = process_labels(coords, shape, self.sigma, self.eps, bb=bb)
210244
211245 # For SGN detection with data specfic hacks
212- label = process_labels_hacky (coords , shape , self .sigma , self .eps , bb = bb )
246+ # label = process_labels_hacky(coords, shape, self.sigma, self.eps, bb=bb)
247+ label = process_labels_stamped (coords , shape , self .sigma , self .eps , bb = bb )
213248 # Having this halo actually makes sense in general!
214249 gap = 8
215250 gap_bb = np .s_ [gap :- gap , gap :- gap , gap :- gap ]
0 commit comments