Skip to content

Commit a71f6c1

Browse files
Update commit to file path functionality (#959)
1 parent dd61633 commit a71f6c1

File tree

1 file changed

+132
-41
lines changed

1 file changed

+132
-41
lines changed

micro_sam/sam_annotator/_widgets.py

Lines changed: 132 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,43 @@ def _commit_impl(viewer, layer, preserve_committed):
400400
return id_offset, seg, mask, bb
401401

402402

403+
def _get_auto_segmentation_options(state, object_ids):
404+
widget = state.widgets["autosegment"]
405+
406+
segmentation_options = {"object_ids": [int(object_id) for object_id in object_ids]}
407+
if widget.with_decoder:
408+
segmentation_options["boundary_distance_thresh"] = widget.boundary_distance_thresh
409+
segmentation_options["center_distance_thresh"] = widget.center_distance_thresh
410+
else:
411+
segmentation_options["pred_iou_thresh"] = widget.pred_iou_thresh
412+
segmentation_options["stability_score_thresh"] = widget.stability_score_thresh
413+
segmentation_options["box_nms_thresh"] = widget.box_nms_thresh
414+
415+
segmentation_options["min_object_size"] = widget.min_object_size
416+
segmentation_options["with_background"] = widget.with_background
417+
418+
if widget.volumetric:
419+
segmentation_options["apply_to_volume"] = widget.apply_to_volume
420+
segmentation_options["gap_closing"] = widget.gap_closing
421+
segmentation_options["min_extent"] = widget.min_extent
422+
423+
return segmentation_options
424+
425+
426+
def _get_promptable_segmentation_options(state, object_ids):
427+
segmentation_options = {"object_ids": [int(object_id) for object_id in object_ids]}
428+
is_tracking = False
429+
if "segment_nd" in state.widgets:
430+
widget = state.widgets["segment_nd"]
431+
segmentation_options["projection"] = widget.projection
432+
segmentation_options["iou_threshold"] = widget.iou_threshold
433+
segmentation_options["box_extension"] = widget.box_extension
434+
if widget.tracking:
435+
segmentation_options["motion_smoothing"] = widget.motion_smoothing
436+
is_tracking = True
437+
return segmentation_options, is_tracking
438+
439+
403440
def _commit_to_file(path, viewer, layer, seg, mask, bb, extra_attrs=None):
404441

405442
# NOTE: zarr-python is quite inefficient and writes empty blocks.
@@ -413,23 +450,42 @@ def _commit_to_file(path, viewer, layer, seg, mask, bb, extra_attrs=None):
413450
json.dump({"zarr_format": 2}, f)
414451

415452
f = z5py.ZarrFile(path, "a")
453+
state = AnnotatorState()
416454

417-
# Write metadata about the model that's being used etc.
418-
# Only if it's not written to the file yet.
419-
if "data_signature" not in f.attrs:
420-
state = AnnotatorState()
455+
def _save_signature(f, data_signature):
421456
embeds = state.widgets["embeddings"]
422457
tile_shape, halo = _process_tiling_inputs(embeds.tile_x, embeds.tile_y, embeds.halo_x, embeds.halo_y)
423458
signature = util._get_embedding_signature(
424459
input_=None, # We don't need this because we pass the data signature.
425460
predictor=state.predictor,
426461
tile_shape=tile_shape,
427462
halo=halo,
428-
data_signature=state.data_signature,
463+
data_signature=data_signature,
429464
)
430465
for key, val in signature.items():
431466
f.attrs[key] = val
432467

468+
# If the data signature is saved in the file already,
469+
# then we check if saved data signature and data signature of our image agree.
470+
# If not, this file was used for committing objects from another file.
471+
if "data_signature" in f.attrs:
472+
saved_signature = f.attrs["data_signature"]
473+
current_signature = state.data_signature
474+
if saved_signature != current_signature: # Signatures disagree.
475+
msg = f"The commit_path {path} was already used for saving annotations for different image data:\n"
476+
msg += f"The data signatures are different: {saved_signature} != {current_signature}.\n"
477+
msg += "Press 'Ok' to remove the data already stored in that file and continue annotation.\n"
478+
msg += "Otherwise please select a different file path."
479+
skip_clear = _generate_message("info", msg)
480+
if skip_clear:
481+
return
482+
else:
483+
f = z5py.ZarrFile(path, "w")
484+
_save_signature(f, current_signature)
485+
# Otherwise (data signature not saved yet), write the current signature.
486+
else:
487+
_save_signature(f, state.data_signature)
488+
433489
# Write the segmentation.
434490
full_shape = viewer.layers["committed_objects"].data.shape
435491
block_shape = util.get_block_shape(full_shape)
@@ -445,47 +501,81 @@ def _commit_to_file(path, viewer, layer, seg, mask, bb, extra_attrs=None):
445501
if extra_attrs is not None:
446502
f.attrs.update(extra_attrs)
447503

448-
# If we run commit from the automatic segmentation we don't have
449-
# any prompts and so don't need to commit anything else.
504+
# Get the commit history and the objects that are being commited.
505+
commit_history = f.attrs.get("commit_history", [])
506+
object_ids = np.unique(seg[mask])
507+
508+
# We committed an automatic segmentation.
450509
if layer == "auto_segmentation":
451-
# TODO write the settings for the auto segmentation widget.
510+
# Save the settings of the segmentation widget.
511+
segmentation_options = _get_auto_segmentation_options(state, object_ids)
512+
commit_history.append({"auto_segmentation": segmentation_options})
513+
514+
# Write the commit history.
515+
f.attrs["commit_history"] = commit_history
516+
517+
# If we run commit from the automatic segmentation we don't have
518+
# any prompts and so don't need to commit anything else.
452519
return
453520

454-
def write_prompts(object_id, prompts, point_prompts):
521+
segmentation_options, is_tracking = _get_promptable_segmentation_options(state, object_ids)
522+
commit_history.append({"current_object": segmentation_options})
523+
524+
def write_prompts(object_id, prompts, point_prompts, point_labels, track_state=None):
455525
g = f.create_group(f"prompts/{object_id}")
456526
if prompts is not None and len(prompts) > 0:
457527
data = np.array(prompts)
458528
g.create_dataset("prompts", data=data, chunks=data.shape)
459529
if point_prompts is not None and len(point_prompts) > 0:
460530
g.create_dataset("point_prompts", data=point_prompts, chunks=point_prompts.shape)
531+
ds = g.create_dataset("point_labels", data=point_labels, chunks=point_labels.shape)
532+
if track_state is not None:
533+
ds.attrs["track_state"] = track_state.tolist()
534+
535+
# Get the prompts from the layers.
536+
prompts = viewer.layers["prompts"].data
537+
point_layer = viewer.layers["point_prompts"]
538+
point_prompts = point_layer.data
539+
point_labels = point_layer.properties["label"]
540+
if len(point_prompts) > 0:
541+
point_labels = np.array([1 if label == "positive" else 0 for label in point_labels])
542+
assert len(point_prompts) == len(point_labels), \
543+
f"Number of point prompts and labels disagree: {len(point_prompts)} != {len(point_labels)}"
461544

462-
# TODO write the settings for the segmentation widget if necessary.
463545
# Commit the prompts for all the objects in the commit.
464-
object_ids = np.unique(seg[mask])
465546
if len(object_ids) == 1: # We only have a single object.
466-
write_prompts(object_ids[0], viewer.layers["prompts"].data, viewer.layers["point_prompts"].data)
467-
else:
468-
# TODO this logic has to be updated to be compatible with the new batched prompting
469-
have_prompts = len(viewer.layers["prompts"].data) > 0
470-
have_point_prompts = len(viewer.layers["point_prompts"].data) > 0
471-
if have_prompts and not have_point_prompts:
472-
prompts = viewer.layers["prompts"].data
473-
point_prompts = None
474-
elif not have_prompts and have_point_prompts:
475-
prompts = None
476-
point_prompts = viewer.layers["point_prompts"].data
477-
else:
478-
msg = "Got multiple objects from interactive segmentation with box and point prompts." if (
479-
have_prompts and have_point_prompts
480-
) else "Got multiple objects from interactive segmentation with neither box or point prompts."
481-
raise RuntimeError(msg)
482-
547+
write_prompts(object_ids[0], prompts, point_prompts, point_labels)
548+
549+
elif is_tracking: # We have multiple objects from tracking a lineage with divisions.
550+
track_ids_points = np.array(point_layer.properties["track_id"])
551+
track_ids_prompts = np.array(viewer.layers["prompts"].properties["track_id"])
552+
553+
unique_track_ids = np.unique(track_ids_points)
554+
assert len(unique_track_ids) == len(object_ids)
555+
track_state = np.array(point_layer.properties["state"])
556+
for track_id, object_id in zip(unique_track_ids, object_ids):
557+
this_prompts = None if len(prompts) == 0 else prompts[track_ids_prompts == track_id]
558+
point_mask = track_ids_points == track_id
559+
this_points, this_labels, this_track_state = \
560+
point_prompts[point_mask], point_labels[point_mask], track_state[point_mask]
561+
write_prompts(object_id, this_prompts, this_points, this_labels, track_state=this_track_state)
562+
563+
else: # We have multiple objects, which are the result from batched interactive segmentation.
564+
# Note: we can't match exact object ids to their prompts, for batched segmentation.
565+
# We first write the objects from box prompts, then from point prompts.
566+
n_prompts, n_points = len(prompts), len(point_prompts)
567+
assert n_prompts + n_points == len(object_ids), \
568+
f"Number of prompts and objects disagree: {n_prompts} + {n_points} != {len(object_ids)}"
483569
for i, object_id in enumerate(object_ids):
484-
write_prompts(
485-
object_id,
486-
None if prompts is None else prompts[i:i+1],
487-
None if point_prompts is None else point_prompts[i:i+1]
488-
)
570+
if i < n_prompts:
571+
this_prompts, this_points, this_labels = prompts[i:i+1], None, None
572+
else:
573+
j = i - n_prompts
574+
this_prompts, this_points, this_labels = None, point_prompts[j:j+1], point_labels[j:j+1]
575+
write_prompts(object_id, this_prompts, this_points, this_labels)
576+
577+
# Write the commit history.
578+
f.attrs["commit_history"] = commit_history
489579

490580

491581
@magic_factory(
@@ -629,21 +719,21 @@ def settings_widget(cache_directory: Optional[Path] = util.get_cache_directory()
629719
print(f"micro-sam cache directory set to: {cache_directory}")
630720

631721

632-
def _generate_message(message_type, message) -> bool:
722+
def _generate_message(message_type: str, message: str) -> bool:
633723
"""
634724
Displays a message dialog based on the provided message type.
635725
636726
Args:
637-
message_type (str): The type of message to display. Valid options are:
727+
message_type: The type of message to display. Valid options are:
638728
- "error": Displays a critical error message with an "Ok" button.
639729
- "info": Displays an informational message in a separate dialog box.
640730
The user can dismiss it by either clicking "Ok" or closing the dialog.
641-
message (str): The message content to be displayed in the dialog.
731+
message: The message content to be displayed in the dialog.
642732
643733
Returns:
644-
bool: A flag indicating whether the user aborted the operation based on the
645-
message type. This flag is only set for "info" messages where the user
646-
can choose to cancel (rejected).
734+
A flag indicating whether the user aborted the operation based on the
735+
message type. This flag is only set for "info" messages where the user
736+
can choose to cancel (rejected).
647737
648738
Raises:
649739
ValueError: If an invalid message type is provided.
@@ -659,6 +749,8 @@ def _generate_message(message_type, message) -> bool:
659749
if result == QtWidgets.QDialog.Rejected: # Check for cancel
660750
abort = True # Set flag directly in calling function
661751
return abort
752+
else:
753+
raise ValueError(f"Invalid message type {message_type}")
662754

663755

664756
def _validate_embeddings(viewer: "napari.viewer.Viewer"):
@@ -1140,8 +1232,7 @@ def _create_settings_widget(self):
11401232
return settings
11411233

11421234
def _validate_inputs(self):
1143-
"""
1144-
Validates the inputs for the annotation process and returns a dictionary
1235+
"""Validates the inputs for the annotation process and returns a dictionary
11451236
containing information for message generation, or False if no messages are needed.
11461237
11471238
This function performs the following checks:

0 commit comments

Comments
 (0)