Skip to content

Commit 66ed39b

Browse files
Update to sgn detection training
1 parent f797bc7 commit 66ed39b

File tree

3 files changed

+29
-9
lines changed

3 files changed

+29
-9
lines changed

flamingo_tools/segmentation/postprocessing.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -774,7 +774,8 @@ def split_object(object_id):
774774
nonlocal offset
775775

776776
row = segmentation_table[segmentation_table.label_id == object_id]
777-
if min_size and row.n_pixels.values[0] < min_size:
777+
if row.n_pixels.values[0] < min_size:
778+
# print(object_id, ": min-size")
778779
return [object_id]
779780

780781
bb_min = np.array([
@@ -788,6 +789,12 @@ def split_object(object_id):
788789
bb_max = np.minimum(bb_max.astype(int) + 1, np.array(list(segmentation.shape)))
789790
bb = tuple(slice(mi, ma) for mi, ma in zip(bb_min, bb_max))
790791

792+
# This is due to segmentation artifacts.
793+
bb_shape = bb_max - bb_min
794+
if (bb_shape > 500).any():
795+
print(object_id, "has a too large shape:", bb_shape)
796+
return [object_id]
797+
791798
seg = segmentation[bb]
792799
mask = ~find_boundaries(seg)
793800
dist = distance_transform_edt(mask, sampling=resolution)
@@ -798,6 +805,7 @@ def split_object(object_id):
798805
maxima = peak_local_max(dist, min_distance=3, exclude_border=True)
799806

800807
if len(maxima) == 1:
808+
# print(object_id, ": max len")
801809
return [object_id]
802810

803811
with lock:
@@ -819,14 +827,20 @@ def split_object(object_id):
819827

820828
keep_ids = seg_ids[sizes > min_size]
821829
if len(keep_ids) < 2:
830+
# print(object_id, ": keep-id")
822831
return [object_id]
823832

824833
elif len(keep_ids) != len(seg_ids):
825834
new_seg[~np.isin(new_seg, keep_ids)] = 0
826835
new_seg = watershed(hmap, markers=new_seg, mask=seg_mask)
827836

828-
output[bb][seg_mask] = new_seg[seg_mask]
829-
return seg_ids.tolist()
837+
with lock:
838+
out = output[bb]
839+
out[seg_mask] = new_seg[seg_mask]
840+
output[bb] = out
841+
842+
# print(object_id, ":", len(keep_ids))
843+
return keep_ids.tolist()
830844

831845
# import napari
832846
# v = napari.Viewer()
@@ -839,11 +853,15 @@ def split_object(object_id):
839853
if component_labels is None:
840854
object_ids = segmentation_table.label_id.values
841855
else:
842-
object_ids = segmentation_table[segmentation_table.isin(component_labels)].label_id.values
856+
object_ids = segmentation_table[segmentation_table.component_labels.isin(component_labels)].label_id.values
843857

844858
if n_threads is None:
845859
n_threads = mp.cpu_count()
846860

861+
# new_id_mapping = []
862+
# for object_id in tqdm(object_ids, desc="Split non-convex objects"):
863+
# new_id_mapping.append(split_object(object_id))
864+
847865
with futures.ThreadPoolExecutor(n_threads) as tp:
848866
new_id_mapping = list(
849867
tqdm(tp.map(split_object, object_ids), total=len(object_ids), desc="Split non-convex objects")

scripts/la-vision/train_sgn_detection.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@
1212
from utils.training.training import supervised_training # noqa
1313
from detection_dataset import DetectionDataset, MinPointSampler # noqa
1414

15-
ROOT = "./la-vision-sgn-new" # noqa
15+
# ROOT = "./la-vision-sgn-new" # noqa
16+
ROOT = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/training_data/SGN/sgn-detection"
1617

1718
TRAIN = os.path.join(ROOT, "images")
1819
TRAIN_EMPTY = os.path.join(ROOT, "empty_images")
@@ -24,6 +25,7 @@
2425
def _get_paths(split, train_folder, label_folder, n=None):
2526
image_paths = sorted(glob(os.path.join(train_folder, "*.tif")))
2627
label_paths = sorted(glob(os.path.join(label_folder, "*.csv")))
28+
assert len(image_paths) > 0
2729
assert len(image_paths) == len(label_paths)
2830
if n is not None:
2931
image_paths, label_paths = image_paths[:n], label_paths[:n]

scripts/synapse_marker_detection/detection_dataset.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -206,12 +206,12 @@ def _get_sample(self, index):
206206
raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")
207207

208208
# For synapse detection.
209-
# label = process_labels(coords, shape, self.sigma, self.eps, bb=bb)
209+
label = process_labels(coords, shape, self.sigma, self.eps, bb=bb)
210210

211211
# For SGN detection with data specfic hacks
212-
label = process_labels_hacky(coords, shape, self.sigma, self.eps, bb=bb)
213-
gap = 6
214-
raw_patch, label = raw_patch[gap:-gap], label[gap:-gap]
212+
# label = process_labels_hacky(coords, shape, self.sigma, self.eps, bb=bb)
213+
# gap = 8
214+
# raw_patch, label = raw_patch[gap:-gap], label[gap:-gap]
215215

216216
have_label_channels = label.ndim == 4
217217
if have_label_channels:

0 commit comments

Comments
 (0)