@@ -400,6 +400,43 @@ def _commit_impl(viewer, layer, preserve_committed):
400400 return id_offset , seg , mask , bb
401401
402402
403+ def _get_auto_segmentation_options (state , object_ids ):
404+ widget = state .widgets ["autosegment" ]
405+
406+ segmentation_options = {"object_ids" : [int (object_id ) for object_id in object_ids ]}
407+ if widget .with_decoder :
408+ segmentation_options ["boundary_distance_thresh" ] = widget .boundary_distance_thresh
409+ segmentation_options ["center_distance_thresh" ] = widget .center_distance_thresh
410+ else :
411+ segmentation_options ["pred_iou_thresh" ] = widget .pred_iou_thresh
412+ segmentation_options ["stability_score_thresh" ] = widget .stability_score_thresh
413+ segmentation_options ["box_nms_thresh" ] = widget .box_nms_thresh
414+
415+ segmentation_options ["min_object_size" ] = widget .min_object_size
416+ segmentation_options ["with_background" ] = widget .with_background
417+
418+ if widget .volumetric :
419+ segmentation_options ["apply_to_volume" ] = widget .apply_to_volume
420+ segmentation_options ["gap_closing" ] = widget .gap_closing
421+ segmentation_options ["min_extent" ] = widget .min_extent
422+
423+ return segmentation_options
424+
425+
426+ def _get_promptable_segmentation_options (state , object_ids ):
427+ segmentation_options = {"object_ids" : [int (object_id ) for object_id in object_ids ]}
428+ is_tracking = False
429+ if "segment_nd" in state .widgets :
430+ widget = state .widgets ["segment_nd" ]
431+ segmentation_options ["projection" ] = widget .projection
432+ segmentation_options ["iou_threshold" ] = widget .iou_threshold
433+ segmentation_options ["box_extension" ] = widget .box_extension
434+ if widget .tracking :
435+ segmentation_options ["motion_smoothing" ] = widget .motion_smoothing
436+ is_tracking = True
437+ return segmentation_options , is_tracking
438+
439+
403440def _commit_to_file (path , viewer , layer , seg , mask , bb , extra_attrs = None ):
404441
405442 # NOTE: zarr-python is quite inefficient and writes empty blocks.
@@ -413,23 +450,42 @@ def _commit_to_file(path, viewer, layer, seg, mask, bb, extra_attrs=None):
413450 json .dump ({"zarr_format" : 2 }, f )
414451
415452 f = z5py .ZarrFile (path , "a" )
453+ state = AnnotatorState ()
416454
417- # Write metadata about the model that's being used etc.
418- # Only if it's not written to the file yet.
419- if "data_signature" not in f .attrs :
420- state = AnnotatorState ()
455+ def _save_signature (f , data_signature ):
421456 embeds = state .widgets ["embeddings" ]
422457 tile_shape , halo = _process_tiling_inputs (embeds .tile_x , embeds .tile_y , embeds .halo_x , embeds .halo_y )
423458 signature = util ._get_embedding_signature (
424459 input_ = None , # We don't need this because we pass the data signature.
425460 predictor = state .predictor ,
426461 tile_shape = tile_shape ,
427462 halo = halo ,
428- data_signature = state . data_signature ,
463+ data_signature = data_signature ,
429464 )
430465 for key , val in signature .items ():
431466 f .attrs [key ] = val
432467
468+ # If the data signature is saved in the file already,
469+ # then we check if saved data signature and data signature of our image agree.
470+ # If not, this file was used for committing objects from another file.
471+ if "data_signature" in f .attrs :
472+ saved_signature = f .attrs ["data_signature" ]
473+ current_signature = state .data_signature
474+ if saved_signature != current_signature : # Signatures disagree.
475+ msg = f"The commit_path { path } was already used for saving annotations for different image data:\n "
476+ msg += f"The data signatures are different: { saved_signature } != { current_signature } .\n "
477+ msg += "Press 'Ok' to remove the data already stored in that file and continue annotation.\n "
478+ msg += "Otherwise please select a different file path."
479+ skip_clear = _generate_message ("info" , msg )
480+ if skip_clear :
481+ return
482+ else :
483+ f = z5py .ZarrFile (path , "w" )
484+ _save_signature (f , current_signature )
485+ # Otherwise (data signature not saved yet), write the current signature.
486+ else :
487+ _save_signature (f , state .data_signature )
488+
433489 # Write the segmentation.
434490 full_shape = viewer .layers ["committed_objects" ].data .shape
435491 block_shape = util .get_block_shape (full_shape )
@@ -445,47 +501,81 @@ def _commit_to_file(path, viewer, layer, seg, mask, bb, extra_attrs=None):
445501 if extra_attrs is not None :
446502 f .attrs .update (extra_attrs )
447503
448- # If we run commit from the automatic segmentation we don't have
449- # any prompts and so don't need to commit anything else.
504+ # Get the commit history and the objects that are being commited.
505+ commit_history = f .attrs .get ("commit_history" , [])
506+ object_ids = np .unique (seg [mask ])
507+
508+ # We committed an automatic segmentation.
450509 if layer == "auto_segmentation" :
451- # TODO write the settings for the auto segmentation widget.
510+ # Save the settings of the segmentation widget.
511+ segmentation_options = _get_auto_segmentation_options (state , object_ids )
512+ commit_history .append ({"auto_segmentation" : segmentation_options })
513+
514+ # Write the commit history.
515+ f .attrs ["commit_history" ] = commit_history
516+
517+ # If we run commit from the automatic segmentation we don't have
518+ # any prompts and so don't need to commit anything else.
452519 return
453520
454- def write_prompts (object_id , prompts , point_prompts ):
521+ segmentation_options , is_tracking = _get_promptable_segmentation_options (state , object_ids )
522+ commit_history .append ({"current_object" : segmentation_options })
523+
524+ def write_prompts (object_id , prompts , point_prompts , point_labels , track_state = None ):
455525 g = f .create_group (f"prompts/{ object_id } " )
456526 if prompts is not None and len (prompts ) > 0 :
457527 data = np .array (prompts )
458528 g .create_dataset ("prompts" , data = data , chunks = data .shape )
459529 if point_prompts is not None and len (point_prompts ) > 0 :
460530 g .create_dataset ("point_prompts" , data = point_prompts , chunks = point_prompts .shape )
531+ ds = g .create_dataset ("point_labels" , data = point_labels , chunks = point_labels .shape )
532+ if track_state is not None :
533+ ds .attrs ["track_state" ] = track_state .tolist ()
534+
535+ # Get the prompts from the layers.
536+ prompts = viewer .layers ["prompts" ].data
537+ point_layer = viewer .layers ["point_prompts" ]
538+ point_prompts = point_layer .data
539+ point_labels = point_layer .properties ["label" ]
540+ if len (point_prompts ) > 0 :
541+ point_labels = np .array ([1 if label == "positive" else 0 for label in point_labels ])
542+ assert len (point_prompts ) == len (point_labels ), \
543+ f"Number of point prompts and labels disagree: { len (point_prompts )} != { len (point_labels )} "
461544
462- # TODO write the settings for the segmentation widget if necessary.
463545 # Commit the prompts for all the objects in the commit.
464- object_ids = np .unique (seg [mask ])
465546 if len (object_ids ) == 1 : # We only have a single object.
466- write_prompts (object_ids [0 ], viewer .layers ["prompts" ].data , viewer .layers ["point_prompts" ].data )
467- else :
468- # TODO this logic has to be updated to be compatible with the new batched prompting
469- have_prompts = len (viewer .layers ["prompts" ].data ) > 0
470- have_point_prompts = len (viewer .layers ["point_prompts" ].data ) > 0
471- if have_prompts and not have_point_prompts :
472- prompts = viewer .layers ["prompts" ].data
473- point_prompts = None
474- elif not have_prompts and have_point_prompts :
475- prompts = None
476- point_prompts = viewer .layers ["point_prompts" ].data
477- else :
478- msg = "Got multiple objects from interactive segmentation with box and point prompts." if (
479- have_prompts and have_point_prompts
480- ) else "Got multiple objects from interactive segmentation with neither box or point prompts."
481- raise RuntimeError (msg )
482-
547+ write_prompts (object_ids [0 ], prompts , point_prompts , point_labels )
548+
549+ elif is_tracking : # We have multiple objects from tracking a lineage with divisions.
550+ track_ids_points = np .array (point_layer .properties ["track_id" ])
551+ track_ids_prompts = np .array (viewer .layers ["prompts" ].properties ["track_id" ])
552+
553+ unique_track_ids = np .unique (track_ids_points )
554+ assert len (unique_track_ids ) == len (object_ids )
555+ track_state = np .array (point_layer .properties ["state" ])
556+ for track_id , object_id in zip (unique_track_ids , object_ids ):
557+ this_prompts = None if len (prompts ) == 0 else prompts [track_ids_prompts == track_id ]
558+ point_mask = track_ids_points == track_id
559+ this_points , this_labels , this_track_state = \
560+ point_prompts [point_mask ], point_labels [point_mask ], track_state [point_mask ]
561+ write_prompts (object_id , this_prompts , this_points , this_labels , track_state = this_track_state )
562+
563+ else : # We have multiple objects, which are the result from batched interactive segmentation.
564+ # Note: we can't match exact object ids to their prompts, for batched segmentation.
565+ # We first write the objects from box prompts, then from point prompts.
566+ n_prompts , n_points = len (prompts ), len (point_prompts )
567+ assert n_prompts + n_points == len (object_ids ), \
568+ f"Number of prompts and objects disagree: { n_prompts } + { n_points } != { len (object_ids )} "
483569 for i , object_id in enumerate (object_ids ):
484- write_prompts (
485- object_id ,
486- None if prompts is None else prompts [i :i + 1 ],
487- None if point_prompts is None else point_prompts [i :i + 1 ]
488- )
570+ if i < n_prompts :
571+ this_prompts , this_points , this_labels = prompts [i :i + 1 ], None , None
572+ else :
573+ j = i - n_prompts
574+ this_prompts , this_points , this_labels = None , point_prompts [j :j + 1 ], point_labels [j :j + 1 ]
575+ write_prompts (object_id , this_prompts , this_points , this_labels )
576+
577+ # Write the commit history.
578+ f .attrs ["commit_history" ] = commit_history
489579
490580
491581@magic_factory (
@@ -629,21 +719,21 @@ def settings_widget(cache_directory: Optional[Path] = util.get_cache_directory()
629719 print (f"micro-sam cache directory set to: { cache_directory } " )
630720
631721
632- def _generate_message (message_type , message ) -> bool :
722+ def _generate_message (message_type : str , message : str ) -> bool :
633723 """
634724 Displays a message dialog based on the provided message type.
635725
636726 Args:
637- message_type (str) : The type of message to display. Valid options are:
727+ message_type: The type of message to display. Valid options are:
638728 - "error": Displays a critical error message with an "Ok" button.
639729 - "info": Displays an informational message in a separate dialog box.
640730 The user can dismiss it by either clicking "Ok" or closing the dialog.
641- message (str) : The message content to be displayed in the dialog.
731+ message: The message content to be displayed in the dialog.
642732
643733 Returns:
644- bool: A flag indicating whether the user aborted the operation based on the
645- message type. This flag is only set for "info" messages where the user
646- can choose to cancel (rejected).
734+ A flag indicating whether the user aborted the operation based on the
735+ message type. This flag is only set for "info" messages where the user
736+ can choose to cancel (rejected).
647737
648738 Raises:
649739 ValueError: If an invalid message type is provided.
@@ -659,6 +749,8 @@ def _generate_message(message_type, message) -> bool:
659749 if result == QtWidgets .QDialog .Rejected : # Check for cancel
660750 abort = True # Set flag directly in calling function
661751 return abort
752+ else :
753+ raise ValueError (f"Invalid message type { message_type } " )
662754
663755
664756def _validate_embeddings (viewer : "napari.viewer.Viewer" ):
@@ -1140,8 +1232,7 @@ def _create_settings_widget(self):
11401232 return settings
11411233
11421234 def _validate_inputs (self ):
1143- """
1144- Validates the inputs for the annotation process and returns a dictionary
1235+ """Validates the inputs for the annotation process and returns a dictionary
11451236 containing information for message generation, or False if no messages are needed.
11461237
11471238 This function performs the following checks:
0 commit comments