diff --git a/.gitignore b/.gitignore index fd2845d..cb794c2 100644 --- a/.gitignore +++ b/.gitignore @@ -83,3 +83,4 @@ venv/ # written by setuptools_scm **/_version.py +*.txt \ No newline at end of file diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..fbd3a51 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "src/co-tracker"] + path = src/co-tracker + url = https://github.com/facebookresearch/co-tracker diff --git a/.napari/DESCRIPTION.md b/.napari/DESCRIPTION.md index 9183c54..4033cba 100644 --- a/.napari/DESCRIPTION.md +++ b/.napari/DESCRIPTION.md @@ -5,18 +5,18 @@ the functionality of your plugin. Its content will be rendered on your plugin's napari hub page. The sections below are given as a guide for the flow of information only, and -are in no way prescriptive. You should feel free to merge, remove, add and -rename sections at will to make this document work best for your plugin. +are in no way prescriptive. You should feel free to merge, remove, add and +rename sections at will to make this document work best for your plugin. # Description -This should be a detailed description of the context of your plugin and its +This should be a detailed description of the context of your plugin and its intended purpose. If you have videos or screenshots of your plugin in action, you should include them -here as well, to make them front and center for new users. +here as well, to make them front and center for new users. -You should use absolute links to these assets, so that we can easily display them +You should use absolute links to these assets, so that we can easily display them on the hub. The easiest way to include a video is to use a GIF, for example hosted on imgur. You can then reference this GIF as an image. @@ -59,7 +59,7 @@ anywhere, feel free to also include this information here. This section should go through step-by-step examples of how your plugin should be used. Where your plugin provides multiple dock widgets or functions, you should split these out into separate subsections for easy browsing. Include screenshots and videos -wherever possible to elucidate your descriptions. +wherever possible to elucidate your descriptions. Ideally, this section should start with minimal examples for those who just want a quick overview of the plugin's functionality, but you should definitely link out to @@ -72,8 +72,8 @@ for the majority of plugins. They will include instructions to pip install, and to install via napari itself. Most plugins can be installed out-of-the-box by just specifying the package requirements -over in `setup.cfg`. However, if your plugin has any more complex dependencies, or -requires any additional preparation before (or after) installation, you should add +over in `setup.cfg`. However, if your plugin has any more complex dependencies, or +requires any additional preparation before (or after) installation, you should add this information here. # Getting Help diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c5d1ce5..350dc89 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,35 +5,26 @@ repos: - id: check-docstring-first - id: end-of-file-fixer - id: trailing-whitespace + - id: check-yaml + - id: check-added-large-files + args: [--maxkb=5000] + - id: check-toml - repo: https://github.com/asottile/setup-cfg-fmt rev: v1.20.0 hooks: - id: setup-cfg-fmt - - repo: https://github.com/PyCQA/flake8 - rev: 4.0.1 + - repo: https://github.com/charliermarsh/ruff-pre-commit + # Ruff version. + rev: 'v0.0.262' hooks: - - id: flake8 - additional_dependencies: [flake8-typing-imports==1.7.0] - - repo: https://github.com/myint/autoflake - rev: v1.4 - hooks: - - id: autoflake - args: ["--in-place", "--remove-all-unused-imports"] - - repo: https://github.com/PyCQA/isort - rev: 5.10.1 - hooks: - - id: isort + - id: ruff + args: [ --fix, --exit-non-zero-on-fix ] - repo: https://github.com/psf/black - rev: 22.1.0 + rev: 23.3.0 hooks: - id: black - - repo: https://github.com/asottile/pyupgrade - rev: v2.31.0 - hooks: - - id: pyupgrade - args: [--py37-plus, --keep-runtime-typing] - repo: https://github.com/tlambert03/napari-plugin-checks - rev: v0.2.0 + rev: v0.3.0 hooks: - id: napari-plugin-checks # https://mypy.readthedocs.io/en/stable/introduction.html diff --git a/pyproject.toml b/pyproject.toml index f8296a1..c7460b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,6 +2,58 @@ requires = ["setuptools", "wheel", "setuptools_scm"] build-backend = "setuptools.build_meta" +[tool.ruff] +target-version = "py38" +select = [ + "E", "F", "W", + "A", + "B", + "G", + "I", + "PT", + "SIM", + "NPY", +] +# Never enforce `E501` (line length violations) and 'E741' (ambiguous variable names) +# and 'G004' (do not use f-strings in logging) +# and 'A003' (Shadowing python builtins) +# and 'F401' (imported but unused) +ignore = ["E501", "E741", "G004", "A003", "F401"] +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".mypy_cache", + ".nox", + ".pants.d", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "venv", + "docs/conf.py", + "napari_cellseg3d/_tests/conftest.py", +] + +[tool.ruff.pydocstyle] +convention = "google" + +[tool.black] +line-length = 79 + +[tool.isort] +profile = "black" +line_length = 79 [tool.setuptools_scm] write_to = "src/napari_deeplabcut/_version.py" diff --git a/src/napari_deeplabcut/_tracking_utils.py b/src/napari_deeplabcut/_tracking_utils.py new file mode 100644 index 0000000..56acb93 --- /dev/null +++ b/src/napari_deeplabcut/_tracking_utils.py @@ -0,0 +1,432 @@ +### ------------- Custom widgets for tracking module -------------- ### +import logging +import threading +from datetime import datetime +from functools import partial +from typing import Optional + +import napari +from qtpy import QtCore +from qtpy.QtCore import QObject +from qtpy.QtGui import QTextCursor +from qtpy.QtWidgets import ( + QComboBox, + QHBoxLayout, + QLabel, + QLayout, + QSizePolicy, + QTextEdit, + QVBoxLayout, + QWidget, +) + +logger = logging.getLogger(__name__) + + +### -------------- UI utilities -------------- ### +class ContainerWidget(QWidget): + """Class for a container widget that can contain other widgets.""" + + def __init__( + self, l=0, t=0, r=1, b=11, vertical=True, parent=None, fixed=True + ): + """Creates a container widget that can contain other widgets. + + Args: + l: left margin in pixels + t: top margin in pixels + r: right margin in pixels + b: bottom margin in pixels + vertical: if True, renders vertically. Horizontal otherwise + parent: parent QWidget + fixed: uses QLayout.SetFixedSize if True + """ + super().__init__(parent) + self.layout = None + + if vertical: + self.layout = QVBoxLayout(self) + else: + self.layout = QHBoxLayout(self) + + self.layout.setContentsMargins(l, t, r, b) + if fixed: + self.layout.setSizeConstraint(QLayout.SetFixedSize) + + +def get_time(): + """Get time in the following format : hour:minute:second. NOT COMPATIBLE with file paths (saving with ":" is invalid).""" + return f"{datetime.now():%H:%M:%S}" + + +def add_widgets(layout, widgets): + """Adds all widgets in the list to layout, with the specified alignment. + + If alignment is None, no alignment is set. + + Args: + layout: layout to add widgets in + widgets: list of QWidgets to add to layout + """ + for w in widgets: + layout.addWidget(w) + + +def make_label(name, parent=None): # TODO update to child class + """Creates a QLabel. + + Args: + name: string with name + parent: parent widget + + Returns: created label + + """ + label = QLabel(name, parent) if parent is not None else QLabel(name) + return label + + +class QWidgetSingleton(type(QObject)): + """To be used as a metaclass when making a singleton QWidget, meaning only one instance exists at a time. + + Avoids unnecessary memory overhead and keeps user parameters even when a widget is closed. + """ + + _instances = {} + + def __call__(cls, *args, **kwargs): + """Ensure only one instance of a QWidget with QWidgetSingleton as a metaclass exists at a time.""" + if cls not in cls._instances: + cls._instances[cls] = super().__call__(*args, **kwargs) + return cls._instances[cls] + + +### -------------- Tracking widgets -------------- ### + + +class DropdownMenu(QComboBox): + """Creates a dropdown menu with a title and adds specified entries to it.""" + + def __init__( + self, + entries: Optional[list] = None, + parent: Optional[QWidget] = None, + text_label: Optional[str] = None, + fixed: Optional[bool] = True, + ): + """Creates a dropdown menu with a title and adds specified entries to it. + + Args: + entries (array(str)): Entries to add to the dropdown menu. Defaults to None, no entries if None + parent (QWidget): parent QWidget to add dropdown menu to. Defaults to None, no parent is set if None + text_label (str) : if not None, creates a QLabel with the contents of 'label', and returns the label as well + fixed (bool): if True, will set the size policy of the dropdown menu to Fixed in h and w. Defaults to True. + """ + super().__init__(parent) + self.label = None + if entries is not None: + self.addItems(entries) + if text_label is not None: + self.label = QLabel(text_label) + if fixed: + self.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) + + def get_items(self): + """Returns the items in the dropdown menu.""" + return [self.itemText(i) for i in range(self.count())] + + +class LayerSelecter(ContainerWidget): + """Class that creates a dropdown menu to select a layer from a napari viewer.""" + + def __init__( + self, viewer, name="Layer", layer_type=napari.layers.Layer, parent=None + ): + """Creates an instance of LayerSelecter.""" + super().__init__(parent=parent, fixed=False) + self._viewer = viewer + self.layer_type = layer_type + + self.layer_list = DropdownMenu( + parent=self, text_label=name, fixed=False + ) + # if it's a keypoint layer, show the number of keypoints + if layer_type == napari.layers.Points: + self.layer_description = make_label( + "Number of keypoints:", parent=self + ) + else: + self.layer_description = make_label("Video :", parent=self) + self.layer_description.setVisible(False) + # self.layer_list.setSizeAdjustPolicy(QComboBox.AdjustToContents) # use tooltip instead ? + + # connect to LayerList events + self._viewer.layers.events.inserted.connect(partial(self._add_layer)) + self._viewer.layers.events.removed.connect(partial(self._remove_layer)) + self._viewer.layers.events.changed.connect(self._check_for_layers) + + # update self.layer_list when layers are added or removed + self.layer_list.currentIndexChanged.connect(self._update_tooltip) + self.layer_list.currentTextChanged.connect(self._update_description) + + add_widgets( + self.layout, + [self.layer_list.label, self.layer_list, self.layer_description], + ) + self._check_for_layers() + + def _get_all_layers(self): + return [ + self.layer_list.itemText(i) for i in range(self.layer_list.count()) + ] + + def _check_for_layers(self): + """Check for layers of the correct type and update the dropdown menu. + + Also removes layers that have been removed from the viewer. + """ + for layer in self._viewer.layers: + layer.events.name.connect(self._rename_layer) + + if ( + isinstance(layer, self.layer_type) + and layer.name not in self._get_all_layers() + ): + logger.debug( + f"Layer {layer.name} - List : {self._get_all_layers()}" + ) + # add new layers of correct type + self.layer_list.addItem(layer.name) + logger.debug(f"Layer {layer.name} has been added to the menu") + # break + # once added, check again for previously renamed layers + self._check_for_removed_layer(layer) + + if layer.name in self._get_all_layers() and not isinstance( + layer, self.layer_type + ): + # remove layers of incorrect type + index = self.layer_list.findText(layer.name) + self.layer_list.removeItem(index) + logger.debug( + f"Layer {layer.name} has been removed from the menu" + ) + + self._check_for_removed_layers() + self._update_tooltip() + self._update_description() + + def _check_for_removed_layer(self, layer): + """Check if a specific layer has been removed from the viewer and must be removed from the menu.""" + if isinstance(layer, str): + name = layer + elif isinstance(layer, self.layer_type): + name = layer.name + else: + logger.warning("Layer is not a string or a valid napari layer") + return + + if name in self._get_all_layers() and name not in [ + l.name for l in self._viewer.layers + ]: + index = self.layer_list.findText(name) + self.layer_list.removeItem(index) + logger.debug(f"Layer {name} has been removed from the menu") + + def _check_for_removed_layers(self): + """Check for layers that have been removed from the viewer and must be removed from the menu.""" + for layer in self._get_all_layers(): + self._check_for_removed_layer(layer) + + def _update_tooltip(self): + self.layer_list.setToolTip(self.layer_list.currentText()) + + def _update_description(self): + try: + if self.layer_list.currentText() != "": + try: + if self.layer_type == napari.layers.Points: + shape_desc = f"{len(self.layer_data())} keypoints" + else: + shape_desc = f"{self.layer_data().shape} frames" + self.layer_description.setText(shape_desc) + self.layer_description.setVisible(True) + except AttributeError: + self.layer_description.setVisible(False) + else: + self.layer_description.setVisible(False) + except KeyError: + self.layer_description.setVisible(False) + + def _add_layer(self, event): + inserted_layer = event.value + + if isinstance(inserted_layer, self.layer_type): + self.layer_list.addItem(inserted_layer.name) + + # check for renaming + inserted_layer.events.name.connect(self._rename_layer) + + def _rename_layer(self, _): + # on layer rename, check for removed/new layers + self._check_for_layers() + + def _remove_layer(self, event): + removed_layer = event.value + + if isinstance( + removed_layer, self.layer_type + ) and removed_layer.name in [ + self.layer_list.itemText(i) for i in range(self.layer_list.count()) + ]: + index = self.layer_list.findText(removed_layer.name) + self.layer_list.removeItem(index) + + def layer(self): + """Returns the layer selected in the dropdown menu.""" + try: + return self._viewer.layers[self.layer_name()] + except ValueError: + return None + + def layer_name(self): + """Returns the name of the layer selected in the dropdown menu.""" + try: + return self.layer_list.currentText() + except (KeyError, ValueError): + logger.warning("Layer list is empty") + return None + + def layer_data(self): + """Returns the data of the layer selected in the dropdown menu.""" + if self.layer_list.count() < 1: + logger.debug("Layer list is empty") + return None + try: + if self.layer_type == napari.layers.Points: + return self.layer().data + else: + return self.layer().data + except (KeyError, ValueError): + msg = f"Layer {self.layer_name()} has no data. Layer might have been renamed or removed." + logger.warning(msg) + return None + + +class Log(QTextEdit): + """Class to implement a log for important user info. Should be thread-safe.""" + + def __init__(self, parent=None): + """Creates a log with a lock for multithreading. + + Args: + parent (QWidget): parent widget to add Log instance to. + """ + super().__init__(parent) + + # from qtpy.QtCore import QMetaType + # parent.qRegisterMetaType("QTextCursor") + + self.lock = threading.Lock() + + def flush(self): + """Flush the log.""" + + def write(self, message): + """Write message to log in a thread-safe manner. + + Args: + message: string to be printed + """ + self.lock.acquire() + try: + if not hasattr(self, "flag"): + self.flag = False + message = message.replace("\r", "").rstrip() + if message: + method = "replace_last_line" if self.flag else "append" + QtCore.QMetaObject.invokeMethod( + self, + method, + QtCore.Qt.QueuedConnection, + QtCore.Q_ARG(str, message), + ) + self.flag = True + else: + self.flag = False + + finally: + self.lock.release() + + @QtCore.Slot(str) + def replace_last_line(self, text): + """Replace last line. For use in progress bar. + + Args: + text: string to be printed + """ + self.lock.acquire() + try: + cursor = self.textCursor() + cursor.movePosition(QTextCursor.End) + cursor.select(QTextCursor.BlockUnderCursor) + cursor.removeSelectedText() + cursor.insertBlock() + self.setTextCursor(cursor) + self.insertPlainText(text) + finally: + self.lock.release() + + def print_and_log(self, text, printing=True): + """Utility used to both print to terminal and log text to a QTextEdit item in a thread-safe manner. Use only for important user info. + + Args: + text (str): Text to be printed and logged + printing (bool): Whether to print the message as well or not using logger.info(). Defaults to True. + + """ + self.lock.acquire() + try: + if printing: + logger.info(text) + # causes issue if you clik on terminal (tied to CMD QuickEdit mode on Windows) + self.moveCursor(QTextCursor.End) + self.insertPlainText(f"\n{text}") + self.verticalScrollBar().setValue( + self.verticalScrollBar().maximum() + ) + finally: + self.lock.release() + + def warn(self, warning): + """Show logger.warning from another thread. + + Args: + warning: warning to be printed + """ + self.lock.acquire() + try: + logger.warning(warning) + finally: + self.lock.release() + + def error(self, error, msg=None): + """Show exception and message from another thread. + + Args: + error: error to be printed + msg: message to be printed + """ + self.lock.acquire() + try: + logger.error(error, exc_info=True) + if msg is not None: + self.print_and_log(f"{msg} : {error}", printing=False) + else: + self.print_and_log( + f"Exception caught in another thread : {error}", + printing=False, + ) + raise error + finally: + self.lock.release() diff --git a/src/napari_deeplabcut/_tracking_worker.py b/src/napari_deeplabcut/_tracking_worker.py new file mode 100644 index 0000000..a1f844f --- /dev/null +++ b/src/napari_deeplabcut/_tracking_worker.py @@ -0,0 +1,665 @@ +# TODO : +# - Implement the tracking worker with multithreading +# - Implement a Log widget to display the tracking progress + a progress bar +# - Prepare I/O with the actual tracking backend +from functools import partial +from pathlib import Path +from dataclasses import dataclass +import napari +import numpy as np +import pandas as pd +import torch +from cotracker.predictor import CoTrackerOnlinePredictor +from napari._qt.qthreading import GeneratorWorker +from qtpy.QtCore import Signal +from qtpy.QtWidgets import ( + QProgressBar, + QPushButton, + QSizePolicy, + QVBoxLayout, + QWidget, +) +from superqt.utils._qthreading import GeneratorWorkerSignals, WorkerBaseSignals + +from napari_deeplabcut._tracking_utils import ( + ContainerWidget, + LayerSelecter, + Log, + QWidgetSingleton, + add_widgets, + get_time, +) + + +def get_track_filename(start_frame: int) -> str: + return f"TrackedKeypoints_start_{start_frame}.h5" + + +@dataclass +class TrackingConfig: + ### Data ### + video: np.ndarray + keypoints: np.ndarray + # result_layer: napari.layers.Points + ### Metadata ### + root: str # path to the video + paths: list # list of paths to the video frames + bodyparts: list # list of bodyparts + individuals_ids: list # list of individuals + ### Config from data ### + n_frames: int + n_animals: int + n_keypoints: int + ### User config ### + retrack_frame_id: int = None + method : str = "CoTracker" # change when adding PIPS++ + device: str = "cpu" if not torch.cuda.is_available() else "cuda" + +@dataclass +class TrackingResults: # Add anything relevant to be yielded by the worker here + """Used to update the results and progress bar. Is yielded by the worker.""" + result_keypoints: np.ndarray = None + layer_metadata: dict = None + hdf_path : str = None + pbar_update : tuple = None + +class TrackingModule(QWidget, metaclass=QWidgetSingleton): + """Plugin for tracking.""" + + def __init__(self, napari_viewer: "napari.viewer.Viewer", parent=None): + """Creates a widget with links to documentation and about page.""" + super().__init__(parent=parent) + self._viewer = napari_viewer + self._worker = None + self._video_layer : napari.layers.Image = None + self._keypoint_layer : napari.layers.Points = None + self.result_layer : napari.layers.Points = None + ### Widgets ### + self.video_layer_dropdown = LayerSelecter( + self._viewer, + name="Video layer", + layer_type=napari.layers.Image, + parent=self, + ) + self.keypoint_layer_dropdown = LayerSelecter( + self._viewer, + name="Keypoint layer", + layer_type=napari.layers.Points, + parent=self, + ) + self.start_button = QPushButton("Start tracking") + self.start_button.clicked.connect(self._start) + self.enable_retracking = False + self._check_for_retracking_availability() + self._viewer.layers.events.inserted.connect(self._check_for_retracking_availability) + + ############################# + # status report docked widget + self.container_docked = False # check if already docked + + self.report_container = ContainerWidget(l=10, t=5, r=5, b=5) + + self.report_container.setSizePolicy( + QSizePolicy.Fixed, QSizePolicy.Minimum + ) + self.progress = QProgressBar(self.report_container) + self.progress.setVisible(False) + """Widget for the progress bar""" + + self.log = Log(self.report_container) + self.log.setVisible(False) + """Read-only display for process-related info. Use only for info destined to user.""" + self._build() + self._viewer.dims.events.current_step.connect(self._update_start_button_display) + + # Use @property to get/set the keypoint layer + @property + def keypoint_layer(self): + """Get the keypoint layer.""" + return self._keypoint_layer + + @keypoint_layer.setter + def keypoint_layer(self, layer_name): + """Set the keypoint layer from the viewer.""" + for l in self._viewer.layers: + if l.name == layer_name: + self._keypoint_layer = l + break + + def _build(self): + """Create a TrackingModule plugin with : + + - A dropdown menu to select the keypoint layer + - A set of keypoints to track + - A button to start tracking + - A Log that shows when starting. providing feedback on the tracking process + """ + layout = QVBoxLayout() + + widgets = [ + self.video_layer_dropdown, + self.keypoint_layer_dropdown, + self.start_button, + ] + add_widgets(layout, widgets) + self.setLayout(layout) + + def _check_ready(self): + """Check if the inputs are ready for tracking.""" + if self.video_layer_dropdown.layer is None: + return False + if self.keypoint_layer_dropdown.layer is None: + return False + return True + + def _display_status_report(self): + """Adds a text log, a progress bar and a "save log" button on the left side of the viewer (usually when starting a worker).""" + + if self.container_docked: + self.log.clear() + elif not self.container_docked: + add_widgets( + self.report_container.layout, + [self.progress, self.log], + ) + self.report_container.setLayout(self.report_container.layout) + report_dock = self._viewer.window.add_dock_widget( + self.report_container, + name="Status report", + area="left", + allowed_areas=["left", "right"], + ) + report_dock._close_btn = False + + # self.docked_widgets.append(report_dock) + self.container_docked = True + + self.log.setVisible(True) + # self.progress.setVisible(True) + # self.progress.setValue(0) + + def _update_start_button_display(self): + """Update the start button display.""" + if self._worker is not None: + if self._worker.is_running: + return + + if self.result_layer is not None and self.enable_retracking: + if self._worker is not None: + if self._worker.is_running: + return + current_frame = self._viewer.dims.current_step[0] + if current_frame == 0: + self.start_button.setText("Start") + return + self.start_button.setText(f"Retrack from frame {current_frame}") + else: + self.start_button.setText("Start") + + def _check_for_retracking_availability(self): + base_track_filename = get_track_filename(start_frame=0) + print(base_track_filename) + base_track_stem = base_track_filename.split(".h5")[0] + print("stem", base_track_stem) + for layer in self._viewer.layers: + print(f"Layer: '{layer.name}'") + if base_track_stem in layer.name: + self.enable_retracking = True + self.result_layer = layer + self._update_start_button_display() + + def _update_progress_bar(self, current_frame, total_frame): + """Update the progress bar.""" + pbar_value = (current_frame / total_frame) * 100 + if pbar_value > 100: + pbar_value = 100 + + self.progress.setValue(pbar_value) + + def _start(self): + """Start the tracking process.""" + print("Started tracking") + + ### Below is code to start the worker and update the button for the use to start/stop the tracking process + if not self._check_ready(): + err = "Aborting, please choose valid inputs" + self.log.print_and_log(err) + raise ValueError(err) + + if self._worker is not None: + if self._worker.is_running: + pass + else: + self._worker.start() + self.start_button.setText("Running... Click to stop") + + else: + self.log.print_and_log("Starting...") + self.log.print_and_log("*" * 20) + self._setup_worker() + + if self._worker.is_running: # if worker is running, tries to stop + self.log.print_and_log( + "Stop request, waiting for next inference..." + ) + self.start_button.setText("Stopping...") + self._worker.quit() + else: # once worker is started, update buttons + self._worker.start() + self.start_button.setText("Running... Click to stop") + + def _setup_worker(self): + metadata = self.keypoint_layer_dropdown.layer().metadata + keypoint_cord = self.keypoint_layer_dropdown.layer_data() + frames = self.video_layer_dropdown.layer_data() + + header = metadata["header"] + bodyparts = header.bodyparts + individuals_ids = header.individuals + + current_frame = self._viewer.dims.current_step[0] + retrack_frame_id = None + if current_frame != 0 and self.enable_retracking: + retrack_frame_id = current_frame + + self.worker_config = TrackingConfig( + video=frames, + keypoints=keypoint_cord, + root=metadata["root"], + paths=metadata["paths"], + bodyparts=bodyparts, + individuals_ids=individuals_ids, + n_frames=len(frames), + n_animals=len(individuals_ids), + n_keypoints=len(bodyparts), + retrack_frame_id=retrack_frame_id, + ) + self._worker = TrackingWorker(self.worker_config) + + + # self._worker = TrackingWorker( + # # metadata["metadata"]["root"], + # # metadata["metadata"]["images"], + # metadata["root"], + # metadata["paths"], + # bodyparts, + # individuals_ids, + # frames, + # keypoint_cord, + # ) + + self._worker.started.connect(self._on_start) + + self._worker.log_signal.connect(self.log.print_and_log) + self._worker.log_w_replace_signal.connect(self.log.replace_last_line) + self._worker.warn_signal.connect(self.log.warn) + self._worker.error_signal.connect(self.log.error) + + self._worker.yielded.connect(partial(self._on_yield)) + self._worker.errored.connect(partial(self._on_error)) + self._worker.finished.connect(self._on_finish) + + def _display_results(self, results): + """Display the results in the viewer, using the method already implemented in the viewer.""" + # path_test = "C:/Users/Cyril/Desktop/Code/DeepLabCut/examples/openfield-Pranav-2018-10-30/labeled-data/m4s1/CollectedData_Pranav.h5" + from napari_deeplabcut._reader import read_hdf + + path_test = str(results) + keypoint_data, metadata, _ = read_hdf(path_test)[0] + # hdf data contains : keypoint data, metadata, and "points" + # we want to create a points layer from the keypoint data + # layer properties (dict) should be populated with metadata + print(metadata) + frame_id = self._worker.config.retrack_frame_id if self._worker.config.retrack_frame_id is not None else 0 + return self._viewer.add_points( + ### data ### + keypoint_data, + name=results.stem, + metadata=metadata["metadata"], + # features=metadata["properties"], + properties=metadata["properties"], + ### display properties ### + face_color=metadata["face_color"], + face_color_cycle=metadata["face_color_cycle"], + face_colormap=metadata["face_colormap"], + edge_color=metadata["edge_color"], + edge_color_cycle=metadata["edge_color_cycle"], + edge_width=metadata["edge_width"], + edge_width_is_relative=metadata["edge_width_is_relative"], + size=metadata["size"], + ) + + + def _on_yield(self, results): + # TODO : display the results in the viewer + # Testing version where an int i is yielded + self.result_layer = self._display_results(results) + ############################ + self.log.print_and_log(f"Yielded {results}") + # self._update_progress_bar(results, 10) + ############################ + + def _on_start(self): + """Catches start signal from worker to call :py:func:`~display_status_report`.""" + self._display_status_report() + self.log.print_and_log(f"Worker started at {get_time()}") + self.log.print_and_log("Worker is running...") + + def _on_error(self, error): + """Catches errors and tries to clean up.""" + self.log.print_and_log("!" * 20) + self.log.print_and_log("Worker errored...") + self.log.error(error) + self._worker.quit() + self.on_finish() + + def _on_finish(self): + """Catches finished signal from worker, resets workspace for next run.""" + self.log.print_and_log(f"\nWorker finished at {get_time()}") + self.log.print_and_log("*" * 20) + self.start_button.setText("Start") + self._update_start_button_display() + + self._worker = None + self.enable_retracking = True + + return True # signal clean exit + + +### -------- Tracking worker -------- ### + + +class LogSignal(WorkerBaseSignals): + """Signal to send messages to be logged from another thread. + + Separate from Worker instances as indicated `on this post`_ + + .. _on this post: https://stackoverflow.com/questions/2970312/pyqt4-qtcore-pyqtsignal-object-has-no-attribute-connect + """ # TODO link ? + + log_signal = Signal(str) + """qtpy.QtCore.Signal: signal to be sent when some text should be logged""" + log_w_replace_signal = Signal(str) + """qtpy.QtCore.Signal: signal to be sent when some text should be logged, replacing the last line""" + warn_signal = Signal(str) + """qtpy.QtCore.Signal: signal to be sent when some warning should be emitted in main thread""" + error_signal = Signal(Exception, str) + """qtpy.QtCore.Signal: signal to be sent when some error should be emitted in main thread""" + + # Should not be an instance variable but a class variable, not defined in __init__, see + # https://stackoverflow.com/questions/2970312/pyqt4-qtcore-pyqtsignal-object-has-no-attribute-connect + + def __init__(self, parent=None): + """Creates a LogSignal.""" + super().__init__(parent=parent) + + +class TrackingWorker(GeneratorWorker): + """A custom worker to run tracking in.""" + + def __init__(self, config: TrackingConfig): + """Creates a TrackingWorker.""" + super().__init__(self.run_tracking) + + self.config = config + + self._root = config.root + self._image_paths = config.paths + self._bodyparts = config.bodyparts + self._individuals = config.individuals_ids + self._video = config.video + self._keypoints = config.keypoints + + self._signals = LogSignal() + self.log_signal = self._signals.log_signal + self.log_w_replace_signal = self._signals.log_w_replace_signal + self.warn_signal = self._signals.warn_signal + self.error_signal = self._signals.error_signal + + def log(self, msg): + """Log a message.""" + self.log_signal.emit(msg) + + def log_w_replace(self, msg): + """Log a message, replacing the last line. For us with progress bars mainly.""" + self.log_w_replace_signal.emit(msg) + + def warn(self, msg): + """Log a warning.""" + self.warn_signal.emit(msg) + + def run_tracking(self): + """Run the tracking.""" + with open("log.txt", "w") as f: + f.write(f"{self._video.shape}") + f.write(f"{self._keypoints.shape}") + + retrack_frame = self.config.retrack_frame_id if self.config.retrack_frame_id is not None else 0 + self.log(f"Started tracking from frame {retrack_frame}") + + + video_frames = np.array(self._video[retrack_frame:]) + + keypoints = np.array(self._keypoints) + with open("log_frame_indices.txt", "w") as f: + f.write(f"{keypoints[:, 0]}\n\n") + f.write(f"{keypoints}") + + keypoints = keypoints[keypoints[:, 0] == retrack_frame][:, [2, 1]] + keypoints = keypoints.reshape((len(self._individuals), len(self._bodyparts), 2)) + + tracks = cotrack_online( + self.log, + video_frames, + keypoints, + device=self.config.device, + ) + with open("log_finished_tracking.txt", "w") as f: + f.write(f"Done! {tracks.shape}") + self.log("Finished tracking") + retrack_frame = 0 if retrack_frame is None else retrack_frame + track_path = Path(self._root) / get_track_filename(start_frame=retrack_frame) + self.save_tracking_data(track_path, tracks, "CoTracker", frame=retrack_frame) + # self.create_combined_tracking_data() + self.log("Finished saving") + yield track_path + + def save_tracking_data(self, path: Path, tracks: np.ndarray, scorer: str, frame: int = 0) -> None: + levels = ["scorer", "individuals", "bodyparts", "coords"] + kpt_entries = ["x", "y"] + columns = [] + for i in self._individuals: + for b in self._bodyparts: + columns += [(scorer, i, b, entry) for entry in kpt_entries] + # for i, b in zip(self._individuals[:8], self._bodyparts[:8]): + # columns += [(scorer, i, b, entry) for entry in kpt_entries] + + index = [] + for img_path in self._image_paths[frame:]: + if isinstance(img_path, str): + index.append(tuple(Path(img_path).parts)) + elif isinstance(img_path, tuple): + index.append(img_path) + else: + raise ValueError(f"Incorrect image path format: {img_path}") + + with open("log_df.txt", "w") as f: + f.write(f"{tracks.reshape((len(tracks), -1)).shape}\n") + f.write(f"{len(index)}\n") + f.write(f"{len(columns)}\n") + f.write(f"{self._individuals}\n") + f.write(f"{self._bodyparts}\n") + + dataframe = pd.DataFrame( + data=tracks.reshape((len(tracks), -1)), + index=pd.MultiIndex.from_tuples(index), + columns=pd.MultiIndex.from_tuples(columns, names=levels), + ) + dataframe.to_hdf(path, key="df_with_missing") + + def create_combined_tracking_data(self): + track_files = [ + p for p in Path(self._root).iterdir() + if p.is_file() and p.stem.startswith("TrackedKeypoints") + ] + if len(track_files) <= 1: + return + + df_tracks = { + int(p.stem.split("_")[-1]): pd.read_hdf(p, key="df_with_missing") + for p in track_files + } + df = df_tracks[0] + indices = [i for i in sorted(df_tracks.keys()) if i > 0] + for idx in indices: + new_data = df_tracks[idx] + df.iloc[idx:] = new_data.iloc[idx:] + + path = Path(self._root) / "CombinedTracks.h5" + df.to_hdf(path, key="df_with_missing") + + def fake_tracking(self): + """Fake tracking for testing purposes.""" + for i in range(1): + self.log(f"Tracking frame {i}") + yield i + 1 + + +# TODO: REQUIRES TO RUN pip install src/co-tracker +def cotrack_online( + log, + video, + keypoints, + device: str = "cpu", +) -> np.ndarray: + log("Running CoTracker") + with open("log_cotrack.txt", "w") as f: + f.write(f"video={video.shape}\n") + f.write(f"keypoints={keypoints.shape}\n") + f.write(f"{keypoints}\n") + + def _process_step(window_frames, is_first_step, queries): + with open("log_window_frames.txt", "w") as f: + f.write(f"{len(window_frames)}\n") + f.write(f"{model.step}\n") + f.write(f"{-model.step * 2}\n") + f.write(f"is_first_step={is_first_step}\n") + + video_chunk = ( + torch.tensor(np.stack(window_frames[-model.step * 2:]), device=device) + .float() + .permute(0, 3, 1, 2)[None] + ) # (1, T, 3, H, W) + return model(video_chunk, is_first_step=is_first_step, queries=queries[None]) + + n_frames = len(video) + n_animals, n_keypoints = keypoints.shape[:2] + + model = torch.hub.load("facebookresearch/co-tracker", "cotracker2_online") + model = model.to(device) + video = torch.from_numpy(video).permute(0, 3, 1, 2).unsqueeze(0).float() + window_frames = [] + + queries = np.zeros((n_animals * n_keypoints, 3)) + queries[:, 1:] = keypoints.reshape((-1, 2)) + queries = torch.from_numpy(queries).to(device).float() + + # Iterating over video frames, processing one window at a time: + is_first_step = True + i = 0 + for i, frame in enumerate(video[0]): + frame = frame.permute(1, 2, 0) + if i % model.step == 0 and i != 0: + pred_tracks, pred_visibility = _process_step( + window_frames, + is_first_step, + queries=queries, + ) + is_first_step = False + window_frames.append(frame) + log(f"Finished frame {i}") + + # Processing final frames in case video length is not a multiple of model.step + # TODO: Use visibility + pred_tracks, pred_visibility = _process_step( + window_frames[-(i % model.step) - model.step - 1:], + is_first_step, + queries=queries, + ) + + with open("log_pred_tracks.txt", "w") as f: + f.write(f"{len(pred_tracks)}\n") + f.write(f"{pred_tracks.shape}\n") + + tracks = pred_tracks.squeeze().cpu().numpy() + with open("log_pred_tracks_2.txt", "w") as f: + f.write(f"{len(tracks)}\n") + f.write(f"{tracks.shape}\n") + f.write(f"{(n_frames, n_animals, n_keypoints, 2)}\n") + + return tracks.reshape((n_frames, n_animals, n_keypoints, 2)) + + +def track_mock( + video: np.ndarray, + keypoints: np.ndarray, +) -> np.ndarray: + """Mocks what a tracker would do. + + This method's signature should be re-used by all trackers (PIPS++ and CoTracker). + + Args: + video: The path to a video in which the points should be tracked. + keypoints: The position of keypoints to track in the video. This array should + have shape (n_animals, n_keypoints, 2), where + n_animals: the number of animals to track + n_keypoints: the number of keypoints to track for each individual + 2: as each point is defined by its (x, y) coordinates + + Returns: + an array of shape (num_frames, n_animals, n_keypoints, 2) corresponding to the + position of each keypoint in each frame of the video + """ + return np.repeat(keypoints, (len(video), 1, 1, 1)) + + +def track_cotracker( + video: Path, + keypoints: np.ndarray, +) -> np.ndarray: + """Tracks keypoints in a video using CoTracker. + + Args: + video: The path to a video in which the points should be tracked. + keypoints: The position of keypoints to track in the video. This array should + have shape (n_animals, n_keypoints, 2), where + n_animals: the number of animals to track + n_keypoints: the number of keypoints to track for each individual + 2: as each point is defined by its (x, y) coordinates + + Returns: + an array of shape (num_frames, n_animals, n_keypoints, 2) corresponding to the + position of each keypoint in each frame of the video + """ + # TODO: Implement your code here! + + +def track_pips( + video: Path, + keypoints: np.ndarray, +) -> np.ndarray: + """Tracks keypoints in a video using PIPS++. + + Args: + video: The path to a video in which the points should be tracked. + keypoints: The position of keypoints to track in the video. This array should + have shape (n_animals, n_keypoints, 2), where + n_animals: the number of animals to track + n_keypoints: the number of keypoints to track for each individual + 2: as each point is defined by its (x, y) coordinates + + Returns: + an array of shape (num_frames, n_animals, n_keypoints, 2) corresponding to the + position of each keypoint in each frame of the video + """ + # TODO: Implement your code here! diff --git a/src/napari_deeplabcut/_widgets.py b/src/napari_deeplabcut/_widgets.py index d1512cc..2fcf430 100644 --- a/src/napari_deeplabcut/_widgets.py +++ b/src/napari_deeplabcut/_widgets.py @@ -3,19 +3,21 @@ from collections import defaultdict, namedtuple from copy import deepcopy from datetime import datetime -from functools import partial, cached_property +from functools import cached_property, partial from math import ceil, log10 -import matplotlib.pyplot as plt -import matplotlib.style as mplstyle -import napari -import pandas as pd from pathlib import Path from types import MethodType from typing import Optional, Sequence, Union -from matplotlib.backends.backend_qtagg import FigureCanvas, NavigationToolbar2QT - +import matplotlib.pyplot as plt +import matplotlib.style as mplstyle +import napari import numpy as np +import pandas as pd +from matplotlib.backends.backend_qtagg import ( + FigureCanvas, + NavigationToolbar2QT, +) from napari._qt.widgets.qt_welcome import QtWelcomeLabel from napari.layers import Image, Points, Shapes, Tracks from napari.layers.points._points_key_bindings import register_points_action @@ -23,8 +25,8 @@ from napari.layers.utils.layer_utils import _features_to_properties from napari.utils.events import Event from napari.utils.history import get_save_history, update_save_history -from qtpy.QtCore import Qt, QTimer, Signal, QPoint, QSettings, QSize -from qtpy.QtGui import QPainter, QAction, QCursor, QIcon +from qtpy.QtCore import QPoint, QSettings, QSize, Qt, QTimer, Signal +from qtpy.QtGui import QAction, QCursor, QIcon, QPainter from qtpy.QtSvgWidgets import QSvgWidget from qtpy.QtWidgets import ( QButtonGroup, @@ -50,16 +52,19 @@ from napari_deeplabcut import keypoints from napari_deeplabcut._reader import _load_config -from napari_deeplabcut._writer import _write_config, _write_image, _form_df +from napari_deeplabcut._tracking_worker import TrackingModule +from napari_deeplabcut._writer import _form_df, _write_config, _write_image from napari_deeplabcut.misc import ( + build_color_cycles, encode_categories, - to_os_dir_sep, guarantee_multiindex_rows, - build_color_cycles, + to_os_dir_sep, ) Tip = namedtuple("Tip", ["msg", "pos"]) +# enable debug logging +logging.basicConfig(level=logging.DEBUG) class Shortcuts(QDialog): """Opens a window displaying available napari-deeplabcut shortcuts""" @@ -69,7 +74,9 @@ def __init__(self, parent): self.setParent(parent) self.setWindowTitle("Shortcuts") - image_path = str(Path(__file__).parent / "assets" / "napari_shortcuts.svg") + image_path = str( + Path(__file__).parent / "assets" / "napari_shortcuts.svg" + ) vlayout = QVBoxLayout() svg_widget = QSvgWidget(image_path) @@ -118,7 +125,9 @@ def __init__(self, parent): ] vlayout = QVBoxLayout() - self.message = QLabel("💡\n\nLet's get started with a quick walkthrough!") + self.message = QLabel( + "💡\n\nLet's get started with a quick walkthrough!" + ) self.message.setTextInteractionFlags(Qt.LinksAccessibleByMouse) self.message.setOpenExternalLinks(True) vlayout.addWidget(self.message) @@ -194,10 +203,7 @@ def _get_and_try_preferred_reader( # where property is understood as continuous if # there are more than 16 unique categories... def guess_continuous(property): - if issubclass(property.dtype.type, np.floating): - return True - else: - return False + return bool(issubclass(property.dtype.type, np.floating)) color_manager.guess_continuous = guess_continuous @@ -236,12 +242,17 @@ def _paste_data(self, store): not_disp = self._slice_input.not_displayed data = deepcopy(self._clipboard["data"]) offset = [ - self._slice_indices[i] - self._clipboard["indices"][i] for i in not_disp + self._slice_indices[i] - self._clipboard["indices"][i] + for i in not_disp ] data[:, not_disp] = data[:, not_disp] + np.array(offset) self._data = np.append(self.data, data, axis=0) - self._shown = np.append(self.shown, deepcopy(self._clipboard["shown"]), axis=0) - self._size = np.append(self.size, deepcopy(self._clipboard["size"]), axis=0) + self._shown = np.append( + self.shown, deepcopy(self._clipboard["shown"]), axis=0 + ) + self._size = np.append( + self.size, deepcopy(self._clipboard["size"]), axis=0 + ) self._symbol = np.append( self.symbol, deepcopy(self._clipboard["symbol"]), axis=0 ) @@ -405,6 +416,8 @@ def __init__(self, napari_viewer, parent=None): layout.addLayout(layout2) self.setLayout(layout) + ############################ + self.frames = [] self.keypoints = [] self.df = None @@ -497,7 +510,9 @@ def _load_dataframe(self): if points_layer is None or ~np.any(points_layer.data): return - self.viewer.window.add_dock_widget(self, name="Trajectory plot", area="right") + self.viewer.window.add_dock_widget( + self, name="Trajectory plot", area="right" + ) self.hide() self.df = _form_df( @@ -508,9 +523,13 @@ def _load_dataframe(self): }, ) for keypoint in self.df.columns.get_level_values("bodyparts").unique(): - y = self.df.xs((keypoint, "y"), axis=1, level=["bodyparts", "coords"]) + y = self.df.xs( + (keypoint, "y"), axis=1, level=["bodyparts", "coords"] + ) x = np.arange(len(y)) - color = points_layer.metadata["face_color_cycles"]["label"][keypoint] + color = points_layer.metadata["face_color_cycles"]["label"][ + keypoint + ] lines = self.ax.plot(x, y, color=color, label=keypoint) self._lines[keypoint] = lines @@ -553,9 +572,11 @@ def __init__(self, napari_viewer): self.viewer.layers.events.inserted.connect(self.on_insert) self.viewer.layers.events.removed.connect(self.on_remove) - self.viewer.window.qt_viewer._get_and_try_preferred_reader = MethodType( - _get_and_try_preferred_reader, - self.viewer.window.qt_viewer, + self.viewer.window.qt_viewer._get_and_try_preferred_reader = ( + MethodType( + _get_and_try_preferred_reader, + self.viewer.window.qt_viewer, + ) ) status_bar = self.viewer.window._qt_window.statusBar() @@ -593,7 +614,9 @@ def __init__(self, napari_viewer): self._layout = QVBoxLayout(self) self._menus = [] self._layer_to_menu = {} - self.viewer.layers.selection.events.active.connect(self.on_active_layer_change) + self.viewer.layers.selection.events.active.connect( + self.on_active_layer_change + ) self._video_group = self._form_video_action_menu() self.video_widget = self.viewer.window.add_dock_widget( @@ -635,9 +658,14 @@ def __init__(self, napari_viewer): self._radio_group = self._form_mode_radio_buttons() # form color scheme display + color mode selector - self._color_grp, self._color_mode_selector = self._form_color_mode_selector() + ( + self._color_grp, + self._color_mode_selector, + ) = self._form_color_mode_selector() self._display = ColorSchemeDisplay(parent=self) - self._color_scheme_display = self._form_color_scheme_display(self.viewer) + self._color_scheme_display = self._form_color_scheme_display( + self.viewer + ) self._view_scheme_cb.toggled.connect(self._show_color_scheme) self._view_scheme_cb.toggle() self._display.added.connect( @@ -646,6 +674,10 @@ def __init__(self, napari_viewer): ), ) + # create a QtGroup for the tracking module + self._tracking_group = self._form_tracking_module() + self._layout.addWidget(self._tracking_group) + # Substitute default menu action with custom one for action in self.viewer.window.file_menu.actions()[::-1]: action_name = action.text().lower() @@ -674,8 +706,12 @@ def __init__(self, napari_viewer): self.viewer.window._qt_viewer.viewerButtons.gridViewButton.hide() self.viewer.window._qt_viewer.viewerButtons.rollDimsButton.hide() self.viewer.window._qt_viewer.viewerButtons.transposeDimsButton.hide() - self.viewer.window._qt_viewer.layerButtons.newPointsButton.setDisabled(True) - self.viewer.window._qt_viewer.layerButtons.newLabelsButton.setDisabled(True) + self.viewer.window._qt_viewer.layerButtons.newPointsButton.setDisabled( + True + ) + self.viewer.window._qt_viewer.layerButtons.newLabelsButton.setDisabled( + True + ) if self.settings.value("first_launch", True) and not os.environ.get( "hide_tutorial", False @@ -755,6 +791,14 @@ def _form_help_buttons(self): layout.addWidget(tutorial) return layout + def _form_tracking_module(self): + group_box = QGroupBox("Tracking") + tracking = TrackingModule(napari_viewer=self.viewer, parent=self) + layout = QVBoxLayout() + layout.addWidget(tracking) + group_box.setLayout(layout) + return group_box + def _extract_single_frame(self, *args): image_layer = None points_layer = None @@ -780,8 +824,10 @@ def _extract_single_frame(self, *args): "properties": points_layer.properties, }, ) - df = df.iloc[ind: ind + 1] - df.index = pd.MultiIndex.from_tuples([Path(output_path).parts[-3:]]) + df = df.iloc[ind : ind + 1] + df.index = pd.MultiIndex.from_tuples( + [Path(output_path).parts[-3:]] + ) filepath = os.path.join( image_layer.metadata["root"], "machinelabels-iter0.h5" ) @@ -811,7 +857,9 @@ def _store_crop_coordinates(self, *args): config_path = os.path.join(project_path, "config.yaml") cfg = _load_config(config_path) cfg["video_sets"][ - os.path.join(project_path, "videos", self._images_meta["name"]) + os.path.join( + project_path, "videos", self._images_meta["name"] + ) ] = temp _write_config(config_path, cfg) break @@ -875,7 +923,10 @@ def _form_color_scheme_display(self, viewer): def _update_color_scheme(self): def to_hex(nparray): a = np.array(nparray * 255, dtype=int) - rgb2hex = lambda r, g, b, _: f"#{r:02x}{g:02x}{b:02x}" + + def rgb2hex(r, g, b, _): + return f"#{r:02x}{g:02x}{b:02x}" + res = rgb2hex(*a) return res @@ -889,7 +940,9 @@ def to_hex(nparray): self._display.update_color_scheme( { name: to_hex(color) - for name, color in layer.metadata["face_color_cycles"][mode].items() + for name, color in layer.metadata["face_color_cycles"][ + mode + ].items() } ) @@ -902,7 +955,9 @@ def _remap_frame_indices(self, layer): if paths is not None and np.any(layer.data): paths_map = dict(zip(range(len(paths)), map(to_os_dir_sep, paths))) # Discard data if there are missing frames - missing = [i for i, path in paths_map.items() if path not in new_paths] + missing = [ + i for i, path in paths_map.items() if path not in new_paths + ] if missing: if isinstance(layer.data, list): inds_to_remove = [ @@ -911,7 +966,9 @@ def _remap_frame_indices(self, layer): if verts[0, 0] in missing ] else: - inds_to_remove = np.flatnonzero(np.isin(layer.data[:, 0], missing)) + inds_to_remove = np.flatnonzero( + np.isin(layer.data[:, 0], missing) + ) layer.selected_data = inds_to_remove layer.remove_selected() for i in missing: @@ -930,6 +987,9 @@ def _remap_frame_indices(self, layer): def on_insert(self, event): layer = event.source[-1] + self.update_layer(layer, event) # this was changed while trying to update the colors of the results layer, might be changed back + + def update_layer(self, layer, event=None): logging.debug(f"Inserting Layer {layer}") if isinstance(layer, Image): paths = layer.metadata.get("paths") @@ -945,9 +1005,9 @@ def on_insert(self, event): } ) # Delay layer sorting - QTimer.singleShot( - 10, partial(self._move_image_layer_to_bottom, event.index) - ) + # QTimer.singleShot( + # 10, partial(self._move_image_layer_to_bottom, event.index) + # ) elif isinstance(layer, Points): # If the current Points layer comes from a config file, some have already # been added and the body part names are different from the existing ones, @@ -957,13 +1017,17 @@ def on_insert(self, event): keypoints_menu = self._menus[0].menus["label"] current_keypoint_set = set( - keypoints_menu.itemText(i) for i in range(keypoints_menu.count()) + keypoints_menu.itemText(i) + for i in range(keypoints_menu.count()) ) new_keypoint_set = set(new_metadata["header"].bodyparts) diff = new_keypoint_set.difference(current_keypoint_set) + print(diff) if diff: answer = QMessageBox.question( - self, "", "Do you want to display the new keypoints only?" + self, + "", + "Do you want to display the new keypoints only?", ) if answer == QMessageBox.Yes: self.viewer.layers[-2].shown = False @@ -986,9 +1050,9 @@ def on_insert(self, event): "face_color_cycles" ] _layer.face_color = "label" - _layer.face_color_cycle = new_metadata["face_color_cycles"][ - "label" - ] + _layer.face_color_cycle = new_metadata[ + "face_color_cycles" + ]["label"] _layer.events.face_color() store.layer = _layer self._update_color_scheme() @@ -1073,9 +1137,9 @@ def on_remove(self, event): def on_active_layer_change(self, event) -> None: """Updates the GUI when the active layer changes - * Hides all KeypointsDropdownMenu that aren't for the selected layer - * Sets the visibility of the "Color mode" box to True if the selected layer - is a multi-animal one, or False otherwise + * Hides all KeypointsDropdownMenu that aren't for the selected layer + * Sets the visibility of the "Color mode" box to True if the selected layer + is a multi-animal one, or False otherwise """ self._color_grp.setVisible(self._is_multianimal(event.value)) menu_idx = -1 @@ -1092,7 +1156,8 @@ def _update_colormap(self, colormap_name): for layer in self.viewer.layers.selection: if isinstance(layer, Points) and layer.metadata: face_color_cycle_maps = build_color_cycles( - layer.metadata["header"], colormap_name, + layer.metadata["header"], + colormap_name, ) layer.metadata["face_color_cycles"] = face_color_cycle_maps face_color_prop = "label" @@ -1110,9 +1175,8 @@ def cycle_through_label_modes(self, *args): @register_points_action("Change color mode") def cycle_through_color_modes(self, *args): - if ( - self._active_layer_is_multianimal() - or self.color_mode != str(keypoints.ColorMode.BODYPART) + if self._active_layer_is_multianimal() or self.color_mode != str( + keypoints.ColorMode.BODYPART ): self.color_mode = next(keypoints.ColorMode) @@ -1152,7 +1216,8 @@ def color_mode(self, mode: Union[str, keypoints.ColorMode]): if isinstance(layer, Points) and layer.metadata: layer.face_color = face_color_mode layer.face_color_cycle = layer.metadata["face_color_cycles"][ - face_color_mode] + face_color_mode + ] layer.events.face_color() for btn in self._color_mode_selector.buttons(): @@ -1191,7 +1256,9 @@ def toggle_edge_color(layer): class DropdownMenu(QComboBox): - def __init__(self, labels: Sequence[str], parent: Optional[QWidget] = None): + def __init__( + self, labels: Sequence[str], parent: Optional[QWidget] = None + ): super().__init__(parent) self.update_items(labels) @@ -1461,7 +1528,9 @@ def _format_label(label: QLabel, height: int = None, width: int = None): def _build(self): layout = QHBoxLayout() - layout.addWidget(self.color_label, alignment=Qt.AlignmentFlag.AlignLeft) + layout.addWidget( + self.color_label, alignment=Qt.AlignmentFlag.AlignLeft + ) layout.addWidget(self.part_label, alignment=Qt.AlignmentFlag.AlignLeft) self.setLayout(layout) @@ -1527,7 +1596,9 @@ def _build(self): self.setBaseSize(100, 200) self.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOn) - self.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff) + self.setHorizontalScrollBarPolicy( + Qt.ScrollBarPolicy.ScrollBarAlwaysOff + ) def add_entry(self, name, color): self.scheme_dict.update({name: color}) @@ -1538,7 +1609,9 @@ def add_entry(self, name, color): def update_color_scheme(self, new_color_scheme) -> None: logging.debug(f"Updating color scheme: {self._layout.count()} widgets") - self.scheme_dict = {name: color for name, color in new_color_scheme.items()} + self.scheme_dict = { + name: color for name, color in new_color_scheme.items() + } names = list(new_color_scheme.keys()) existing_widgets = self._layout.count() required_widgets = len(self.scheme_dict) @@ -1555,7 +1628,7 @@ def update_color_scheme(self, new_color_scheme) -> None: for i in range(max(existing_widgets - required_widgets, 0)): logging.debug(f" hiding {required_widgets + i}") if w := self._layout.itemAt(required_widgets + i).widget(): - logging.debug(f" done!") + logging.debug(" done!") w.setVisible(False) # add missing widgets @@ -1563,7 +1636,7 @@ def update_color_scheme(self, new_color_scheme) -> None: logging.debug(f" adding {existing_widgets + i}") name = names[existing_widgets + i] self.add_entry(name, self.scheme_dict[name]) - logging.debug(f" done!") + logging.debug(" done!") def reset(self): self.scheme_dict = {}