From c20ac6c1cf28ba2d4360a6b570cc55853ccc8b96 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 26 Apr 2024 13:55:31 +0200 Subject: [PATCH 01/34] Add first plugin draft --- .gitignore | 1 - .napari/DESCRIPTION.md | 16 +- .pre-commit-config.yaml | 31 +- pyproject.toml | 52 +++ src/napari_deeplabcut/_tracking_utils.py | 426 ++++++++++++++++++++++ src/napari_deeplabcut/_tracking_worker.py | 269 ++++++++++++++ 6 files changed, 766 insertions(+), 29 deletions(-) create mode 100644 src/napari_deeplabcut/_tracking_utils.py create mode 100644 src/napari_deeplabcut/_tracking_worker.py diff --git a/.gitignore b/.gitignore index fd2845d..73d56d3 100644 --- a/.gitignore +++ b/.gitignore @@ -82,4 +82,3 @@ venv/ # written by setuptools_scm **/_version.py - diff --git a/.napari/DESCRIPTION.md b/.napari/DESCRIPTION.md index 9183c54..4033cba 100644 --- a/.napari/DESCRIPTION.md +++ b/.napari/DESCRIPTION.md @@ -5,18 +5,18 @@ the functionality of your plugin. Its content will be rendered on your plugin's napari hub page. The sections below are given as a guide for the flow of information only, and -are in no way prescriptive. You should feel free to merge, remove, add and -rename sections at will to make this document work best for your plugin. +are in no way prescriptive. You should feel free to merge, remove, add and +rename sections at will to make this document work best for your plugin. # Description -This should be a detailed description of the context of your plugin and its +This should be a detailed description of the context of your plugin and its intended purpose. If you have videos or screenshots of your plugin in action, you should include them -here as well, to make them front and center for new users. +here as well, to make them front and center for new users. -You should use absolute links to these assets, so that we can easily display them +You should use absolute links to these assets, so that we can easily display them on the hub. The easiest way to include a video is to use a GIF, for example hosted on imgur. You can then reference this GIF as an image. @@ -59,7 +59,7 @@ anywhere, feel free to also include this information here. This section should go through step-by-step examples of how your plugin should be used. Where your plugin provides multiple dock widgets or functions, you should split these out into separate subsections for easy browsing. Include screenshots and videos -wherever possible to elucidate your descriptions. +wherever possible to elucidate your descriptions. Ideally, this section should start with minimal examples for those who just want a quick overview of the plugin's functionality, but you should definitely link out to @@ -72,8 +72,8 @@ for the majority of plugins. They will include instructions to pip install, and to install via napari itself. Most plugins can be installed out-of-the-box by just specifying the package requirements -over in `setup.cfg`. However, if your plugin has any more complex dependencies, or -requires any additional preparation before (or after) installation, you should add +over in `setup.cfg`. However, if your plugin has any more complex dependencies, or +requires any additional preparation before (or after) installation, you should add this information here. # Getting Help diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index c5d1ce5..350dc89 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,35 +5,26 @@ repos: - id: check-docstring-first - id: end-of-file-fixer - id: trailing-whitespace + - id: check-yaml + - id: check-added-large-files + args: [--maxkb=5000] + - id: check-toml - repo: https://github.com/asottile/setup-cfg-fmt rev: v1.20.0 hooks: - id: setup-cfg-fmt - - repo: https://github.com/PyCQA/flake8 - rev: 4.0.1 + - repo: https://github.com/charliermarsh/ruff-pre-commit + # Ruff version. + rev: 'v0.0.262' hooks: - - id: flake8 - additional_dependencies: [flake8-typing-imports==1.7.0] - - repo: https://github.com/myint/autoflake - rev: v1.4 - hooks: - - id: autoflake - args: ["--in-place", "--remove-all-unused-imports"] - - repo: https://github.com/PyCQA/isort - rev: 5.10.1 - hooks: - - id: isort + - id: ruff + args: [ --fix, --exit-non-zero-on-fix ] - repo: https://github.com/psf/black - rev: 22.1.0 + rev: 23.3.0 hooks: - id: black - - repo: https://github.com/asottile/pyupgrade - rev: v2.31.0 - hooks: - - id: pyupgrade - args: [--py37-plus, --keep-runtime-typing] - repo: https://github.com/tlambert03/napari-plugin-checks - rev: v0.2.0 + rev: v0.3.0 hooks: - id: napari-plugin-checks # https://mypy.readthedocs.io/en/stable/introduction.html diff --git a/pyproject.toml b/pyproject.toml index f8296a1..c7460b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,6 +2,58 @@ requires = ["setuptools", "wheel", "setuptools_scm"] build-backend = "setuptools.build_meta" +[tool.ruff] +target-version = "py38" +select = [ + "E", "F", "W", + "A", + "B", + "G", + "I", + "PT", + "SIM", + "NPY", +] +# Never enforce `E501` (line length violations) and 'E741' (ambiguous variable names) +# and 'G004' (do not use f-strings in logging) +# and 'A003' (Shadowing python builtins) +# and 'F401' (imported but unused) +ignore = ["E501", "E741", "G004", "A003", "F401"] +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".git-rewrite", + ".hg", + ".mypy_cache", + ".nox", + ".pants.d", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "venv", + "docs/conf.py", + "napari_cellseg3d/_tests/conftest.py", +] + +[tool.ruff.pydocstyle] +convention = "google" + +[tool.black] +line-length = 79 + +[tool.isort] +profile = "black" +line_length = 79 [tool.setuptools_scm] write_to = "src/napari_deeplabcut/_version.py" diff --git a/src/napari_deeplabcut/_tracking_utils.py b/src/napari_deeplabcut/_tracking_utils.py new file mode 100644 index 0000000..b9c628a --- /dev/null +++ b/src/napari_deeplabcut/_tracking_utils.py @@ -0,0 +1,426 @@ +### ------------- Custom widgets for tracking module -------------- ### +import logging +import threading +from functools import partial +from typing import Optional + +import napari +from qtpy import QtCore +from qtpy.QtCore import QObject +from qtpy.QtGui import QTextCursor +from qtpy.QtWidgets import ( + QComboBox, + QHBoxLayout, + QLabel, + QLayout, + QSizePolicy, + QTextEdit, + QVBoxLayout, + QWidget, +) + +logger = logging.getLogger(__name__) + + +### -------------- UI utilities -------------- ### +class ContainerWidget(QWidget): + """Class for a container widget that can contain other widgets.""" + + def __init__( + self, l=0, t=0, r=1, b=11, vertical=True, parent=None, fixed=True + ): + """Creates a container widget that can contain other widgets. + + Args: + l: left margin in pixels + t: top margin in pixels + r: right margin in pixels + b: bottom margin in pixels + vertical: if True, renders vertically. Horizontal otherwise + parent: parent QWidget + fixed: uses QLayout.SetFixedSize if True + """ + super().__init__(parent) + self.layout = None + + if vertical: + self.layout = QVBoxLayout(self) + else: + self.layout = QHBoxLayout(self) + + self.layout.setContentsMargins(l, t, r, b) + if fixed: + self.layout.setSizeConstraint(QLayout.SetFixedSize) + + +def add_widgets(layout, widgets): + """Adds all widgets in the list to layout, with the specified alignment. + + If alignment is None, no alignment is set. + + Args: + layout: layout to add widgets in + widgets: list of QWidgets to add to layout + """ + for w in widgets: + layout.addWidget(w) + + +def make_label(name, parent=None): # TODO update to child class + """Creates a QLabel. + + Args: + name: string with name + parent: parent widget + + Returns: created label + + """ + label = QLabel(name, parent) if parent is not None else QLabel(name) + return label + + +class QWidgetSingleton(type(QObject)): + """To be used as a metaclass when making a singleton QWidget, meaning only one instance exists at a time. + + Avoids unnecessary memory overhead and keeps user parameters even when a widget is closed. + """ + + _instances = {} + + def __call__(cls, *args, **kwargs): + """Ensure only one instance of a QWidget with QWidgetSingleton as a metaclass exists at a time.""" + if cls not in cls._instances: + cls._instances[cls] = super().__call__(*args, **kwargs) + return cls._instances[cls] + + +### -------------- Tracking widgets -------------- ### + + +class DropdownMenu(QComboBox): + """Creates a dropdown menu with a title and adds specified entries to it.""" + + def __init__( + self, + entries: Optional[list] = None, + parent: Optional[QWidget] = None, + text_label: Optional[str] = None, + fixed: Optional[bool] = True, + ): + """Creates a dropdown menu with a title and adds specified entries to it. + + Args: + entries (array(str)): Entries to add to the dropdown menu. Defaults to None, no entries if None + parent (QWidget): parent QWidget to add dropdown menu to. Defaults to None, no parent is set if None + text_label (str) : if not None, creates a QLabel with the contents of 'label', and returns the label as well + fixed (bool): if True, will set the size policy of the dropdown menu to Fixed in h and w. Defaults to True. + """ + super().__init__(parent) + self.label = None + if entries is not None: + self.addItems(entries) + if text_label is not None: + self.label = QLabel(text_label) + if fixed: + self.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed) + + def get_items(self): + """Returns the items in the dropdown menu.""" + return [self.itemText(i) for i in range(self.count())] + + +class LayerSelecter(ContainerWidget): + """Class that creates a dropdown menu to select a layer from a napari viewer.""" + + def __init__( + self, viewer, name="Layer", layer_type=napari.layers.Layer, parent=None + ): + """Creates an instance of LayerSelecter.""" + super().__init__(parent=parent, fixed=False) + self._viewer = viewer + self.layer_type = layer_type + + self.layer_list = DropdownMenu( + parent=self, text_label=name, fixed=False + ) + # if it's a keypoint layer, show the number of keypoints + if layer_type == napari.layers.Points: + self.layer_description = make_label( + "Number of keypoints:", parent=self + ) + else: + self.layer_description = make_label("Video :", parent=self) + self.layer_description.setVisible(False) + # self.layer_list.setSizeAdjustPolicy(QComboBox.AdjustToContents) # use tooltip instead ? + + # connect to LayerList events + self._viewer.layers.events.inserted.connect(partial(self._add_layer)) + self._viewer.layers.events.removed.connect(partial(self._remove_layer)) + self._viewer.layers.events.changed.connect(self._check_for_layers) + + # update self.layer_list when layers are added or removed + self.layer_list.currentIndexChanged.connect(self._update_tooltip) + self.layer_list.currentTextChanged.connect(self._update_description) + + add_widgets( + self.layout, + [self.layer_list.label, self.layer_list, self.layer_description], + ) + self._check_for_layers() + + def _get_all_layers(self): + return [ + self.layer_list.itemText(i) for i in range(self.layer_list.count()) + ] + + def _check_for_layers(self): + """Check for layers of the correct type and update the dropdown menu. + + Also removes layers that have been removed from the viewer. + """ + for layer in self._viewer.layers: + layer.events.name.connect(self._rename_layer) + + if ( + isinstance(layer, self.layer_type) + and layer.name not in self._get_all_layers() + ): + logger.debug( + f"Layer {layer.name} - List : {self._get_all_layers()}" + ) + # add new layers of correct type + self.layer_list.addItem(layer.name) + logger.debug(f"Layer {layer.name} has been added to the menu") + # break + # once added, check again for previously renamed layers + self._check_for_removed_layer(layer) + + if layer.name in self._get_all_layers() and not isinstance( + layer, self.layer_type + ): + # remove layers of incorrect type + index = self.layer_list.findText(layer.name) + self.layer_list.removeItem(index) + logger.debug( + f"Layer {layer.name} has been removed from the menu" + ) + + self._check_for_removed_layers() + self._update_tooltip() + self._update_description() + + def _check_for_removed_layer(self, layer): + """Check if a specific layer has been removed from the viewer and must be removed from the menu.""" + if isinstance(layer, str): + name = layer + elif isinstance(layer, self.layer_type): + name = layer.name + else: + logger.warning("Layer is not a string or a valid napari layer") + return + + if name in self._get_all_layers() and name not in [ + l.name for l in self._viewer.layers + ]: + index = self.layer_list.findText(name) + self.layer_list.removeItem(index) + logger.debug(f"Layer {name} has been removed from the menu") + + def _check_for_removed_layers(self): + """Check for layers that have been removed from the viewer and must be removed from the menu.""" + for layer in self._get_all_layers(): + self._check_for_removed_layer(layer) + + def _update_tooltip(self): + self.layer_list.setToolTip(self.layer_list.currentText()) + + def _update_description(self): + try: + if self.layer_list.currentText() != "": + try: + if self.layer_type == napari.layers.Points: + shape_desc = f"{len(self.layer_data())} keypoints" + else: + shape_desc = f"{self.layer_data().shape} frames" + self.layer_description.setText(shape_desc) + self.layer_description.setVisible(True) + except AttributeError: + self.layer_description.setVisible(False) + else: + self.layer_description.setVisible(False) + except KeyError: + self.layer_description.setVisible(False) + + def _add_layer(self, event): + inserted_layer = event.value + + if isinstance(inserted_layer, self.layer_type): + self.layer_list.addItem(inserted_layer.name) + + # check for renaming + inserted_layer.events.name.connect(self._rename_layer) + + def _rename_layer(self, _): + # on layer rename, check for removed/new layers + self._check_for_layers() + + def _remove_layer(self, event): + removed_layer = event.value + + if isinstance( + removed_layer, self.layer_type + ) and removed_layer.name in [ + self.layer_list.itemText(i) for i in range(self.layer_list.count()) + ]: + index = self.layer_list.findText(removed_layer.name) + self.layer_list.removeItem(index) + + def layer(self): + """Returns the layer selected in the dropdown menu.""" + try: + return self._viewer.layers[self.layer_name()] + except ValueError: + return None + + def layer_name(self): + """Returns the name of the layer selected in the dropdown menu.""" + try: + return self.layer_list.currentText() + except (KeyError, ValueError): + logger.warning("Layer list is empty") + return None + + def layer_data(self): + """Returns the data of the layer selected in the dropdown menu.""" + if self.layer_list.count() < 1: + logger.debug("Layer list is empty") + return None + try: + if self.layer_type == napari.layers.Points: + return self.layer().features + else: + return self.layer().data + except (KeyError, ValueError): + msg = f"Layer {self.layer_name()} has no data. Layer might have been renamed or removed." + logger.warning(msg) + return None + + +class Log(QTextEdit): + """Class to implement a log for important user info. Should be thread-safe.""" + + def __init__(self, parent=None): + """Creates a log with a lock for multithreading. + + Args: + parent (QWidget): parent widget to add Log instance to. + """ + super().__init__(parent) + + # from qtpy.QtCore import QMetaType + # parent.qRegisterMetaType("QTextCursor") + + self.lock = threading.Lock() + + def flush(self): + """Flush the log.""" + + def write(self, message): + """Write message to log in a thread-safe manner. + + Args: + message: string to be printed + """ + self.lock.acquire() + try: + if not hasattr(self, "flag"): + self.flag = False + message = message.replace("\r", "").rstrip() + if message: + method = "replace_last_line" if self.flag else "append" + QtCore.QMetaObject.invokeMethod( + self, + method, + QtCore.Qt.QueuedConnection, + QtCore.Q_ARG(str, message), + ) + self.flag = True + else: + self.flag = False + + finally: + self.lock.release() + + @QtCore.Slot(str) + def replace_last_line(self, text): + """Replace last line. For use in progress bar. + + Args: + text: string to be printed + """ + self.lock.acquire() + try: + cursor = self.textCursor() + cursor.movePosition(QTextCursor.End) + cursor.select(QTextCursor.BlockUnderCursor) + cursor.removeSelectedText() + cursor.insertBlock() + self.setTextCursor(cursor) + self.insertPlainText(text) + finally: + self.lock.release() + + def print_and_log(self, text, printing=True): + """Utility used to both print to terminal and log text to a QTextEdit item in a thread-safe manner. Use only for important user info. + + Args: + text (str): Text to be printed and logged + printing (bool): Whether to print the message as well or not using logger.info(). Defaults to True. + + """ + self.lock.acquire() + try: + if printing: + logger.info(text) + # causes issue if you clik on terminal (tied to CMD QuickEdit mode on Windows) + self.moveCursor(QTextCursor.End) + self.insertPlainText(f"\n{text}") + self.verticalScrollBar().setValue( + self.verticalScrollBar().maximum() + ) + finally: + self.lock.release() + + def warn(self, warning): + """Show logger.warning from another thread. + + Args: + warning: warning to be printed + """ + self.lock.acquire() + try: + logger.warning(warning) + finally: + self.lock.release() + + def error(self, error, msg=None): + """Show exception and message from another thread. + + Args: + error: error to be printed + msg: message to be printed + """ + self.lock.acquire() + try: + logger.error(error, exc_info=True) + if msg is not None: + self.print_and_log(f"{msg} : {error}", printing=False) + else: + self.print_and_log( + f"Exception caught in another thread : {error}", + printing=False, + ) + raise error + finally: + self.lock.release() diff --git a/src/napari_deeplabcut/_tracking_worker.py b/src/napari_deeplabcut/_tracking_worker.py new file mode 100644 index 0000000..e426679 --- /dev/null +++ b/src/napari_deeplabcut/_tracking_worker.py @@ -0,0 +1,269 @@ +# TODO : +# - Implement the tracking worker with multithreading +# - Implement a Log widget to display the tracking progress + a progress bar +# - Prepare I/O with the actual tracking backend +from pathlib import Path + +import napari +import numpy as np +from napari._qt.qthreading import GeneratorWorker +from qtpy.QtCore import Signal +from qtpy.QtWidgets import ( + QProgressBar, + QPushButton, + QSizePolicy, + QVBoxLayout, + QWidget, +) +from superqt.utils._qthreading import GeneratorWorkerSignals, WorkerBaseSignals + +from napari_deeplabcut._tracking_utils import ( + ContainerWidget, + LayerSelecter, + Log, + QWidgetSingleton, + add_widgets, +) + + +class TrackingModule(QWidget, metaclass=QWidgetSingleton): + """Plugin for tracking.""" + + def __init__(self, napari_viewer: "napari.viewer.Viewer"): + """Creates a widget with links to documentation and about page.""" + super().__init__() + self._viewer = napari_viewer + self._worker = None + self._keypoint_layer = None + ### Widgets ### + self.video_layer_dropdown = LayerSelecter( + self._viewer, + name="Video layer", + layer_type=napari.layers.Image, + parent=self, + ) + self.keypoint_layer_dropdown = LayerSelecter( + self._viewer, + name="Keypoint layer", + layer_type=napari.layers.Points, + parent=self, + ) + self.start_button = QPushButton("Start tracking") + self.start_button.clicked.connect(self._start) + ############################# + # status report docked widget + self.container_docked = False # check if already docked + + self.report_container = ContainerWidget(l=10, t=5, r=5, b=5) + + self.report_container.setSizePolicy( + QSizePolicy.Fixed, QSizePolicy.Minimum + ) + self.progress = QProgressBar(self.report_container) + self.progress.setVisible(False) + """Widget for the progress bar""" + + self.log = Log(self.report_container) + self.log.setVisible(False) + """Read-only display for process-related info. Use only for info destined to user.""" + self._build() + + # Use @property to get/set the keypoint layer + @property + def keypoint_layer(self): + """Get the keypoint layer.""" + return self._keypoint_layer + + @keypoint_layer.setter + def keypoint_layer(self, layer_name): + """Set the keypoint layer from the viewer.""" + for l in self._viewer.layers: + if l.name == layer_name: + self._keypoint_layer = l + break + + def _build(self): + """Create a TrackingModule plugin with : + + - A dropdown menu to select the keypoint layer + - A set of keypoints to track + - A button to start tracking + - A Log that shows when starting. providing feedback on the tracking process + """ + layout = QVBoxLayout() + + widgets = [ + self.video_layer_dropdown, + self.keypoint_layer_dropdown, + self.start_button, + ] + add_widgets(layout, widgets) + self.setLayout(layout) + + def _check_ready(self): + """Check if the inputs are ready for tracking.""" + if self.video_layer_dropdown.layer is None: + return False + if self.keypoint_layer_dropdown.layer is None: + return False + return True + + def _start(self): + """Start the tracking process.""" + # TODO : implement the tracking process + print("Started tracking") + print(f"Is ready : {self._check_ready()}") + return + if not self.check_ready(): + err = "Aborting, please choose valid inputs" + self.log.print_and_log(err) + raise ValueError(err) + + if self.worker is not None: + if self.worker.is_running: + pass + else: + self.worker.start() + self.btn_start.setText("Running... Click to stop") + else: + self.log.print_and_log("Starting...") + self.log.print_and_log("*" * 20) + # self._set_worker_config() + # if self.worker_config is None: + # raise RuntimeError("Worker config was not set correctly") + # self._setup_worker() + self.btn_close.setVisible(False) + + if self.worker.is_running: # if worker is running, tries to stop + self.log.print_and_log( + "Stop request, waiting for next inference..." + ) + self.btn_start.setText("Stopping...") + self.worker.quit() + else: # once worker is started, update buttons + self.worker.start() + self.btn_start.setText("Running... Click to stop") + + +class LogSignal(WorkerBaseSignals): + """Signal to send messages to be logged from another thread. + + Separate from Worker instances as indicated `on this post`_ + + .. _on this post: https://stackoverflow.com/questions/2970312/pyqt4-qtcore-pyqtsignal-object-has-no-attribute-connect + """ # TODO link ? + + log_signal = Signal(str) + """qtpy.QtCore.Signal: signal to be sent when some text should be logged""" + log_w_replace_signal = Signal(str) + """qtpy.QtCore.Signal: signal to be sent when some text should be logged, replacing the last line""" + warn_signal = Signal(str) + """qtpy.QtCore.Signal: signal to be sent when some warning should be emitted in main thread""" + error_signal = Signal(Exception, str) + """qtpy.QtCore.Signal: signal to be sent when some error should be emitted in main thread""" + + # Should not be an instance variable but a class variable, not defined in __init__, see + # https://stackoverflow.com/questions/2970312/pyqt4-qtcore-pyqtsignal-object-has-no-attribute-connect + + def __init__(self, parent=None): + """Creates a LogSignal.""" + super().__init__(parent=parent) + + +### -------- Tracking code -------- ### + + +class TrackingWorker(GeneratorWorker): + """A custom worker to run tracking in.""" + + def __init__(self, config=None): + """Creates a TrackingWorker.""" + super().__init__(self.run_tracking) + self._signals = LogSignal() + self.log_signal = self._signals.log_signal + self.log_w_replace_signal = self._signals.log_w_replace_signal + self.warn_signal = self._signals.warn_signal + self.error_signal = self._signals.error_signal + + self.config = config # use if needed + + def log(self, msg): + """Log a message.""" + self.log_signal.emit(msg) + + def log_w_replace(self, msg): + """Log a message, replacing the last line. For us with progress bars mainly.""" + self.log_w_replace_signal.emit(msg) + + def warn(self, msg): + """Log a warning.""" + self.warn_signal.emit(msg) + + +def track_mock( + video: Path, + keypoints: np.ndarray, +) -> np.ndarray: + """Mocks what a tracker would do. + + This method's signature should be re-used by all trackers (PIPS++ and CoTracker). + + Args: + video: The path to a video in which the points should be tracked. + keypoints: The position of keypoints to track in the video. This array should + have shape (n_animals, n_keypoints, 2), where + n_animals: the number of animals to track + n_keypoints: the number of keypoints to track for each individual + 2: as each point is defined by its (x, y) coordinates + + Returns: + an array of shape (num_frames, n_animals, n_keypoints, 2) corresponding to the + position of each keypoint in each frame of the video + """ + + def get_num_frames(video: Path) -> int: + return 0 + + return np.repeat(keypoints, (get_num_frames(video), 1, 1, 1)) + + +def track_cotracker( + video: Path, + keypoints: np.ndarray, +) -> np.ndarray: + """Tracks keypoints in a video using CoTracker. + + Args: + video: The path to a video in which the points should be tracked. + keypoints: The position of keypoints to track in the video. This array should + have shape (n_animals, n_keypoints, 2), where + n_animals: the number of animals to track + n_keypoints: the number of keypoints to track for each individual + 2: as each point is defined by its (x, y) coordinates + + Returns: + an array of shape (num_frames, n_animals, n_keypoints, 2) corresponding to the + position of each keypoint in each frame of the video + """ + # TODO: Implement your code here! + + +def track_pips( + video: Path, + keypoints: np.ndarray, +) -> np.ndarray: + """Tracks keypoints in a video using PIPS++. + + Args: + video: The path to a video in which the points should be tracked. + keypoints: The position of keypoints to track in the video. This array should + have shape (n_animals, n_keypoints, 2), where + n_animals: the number of animals to track + n_keypoints: the number of keypoints to track for each individual + 2: as each point is defined by its (x, y) coordinates + + Returns: + an array of shape (num_frames, n_animals, n_keypoints, 2) corresponding to the + position of each keypoint in each frame of the video + """ + # TODO: Implement your code here! From 49316ee58f39c94eff0bc2d1b78e50daa2257fee Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 26 Apr 2024 13:58:20 +0200 Subject: [PATCH 02/34] Update _tracking_worker.py --- src/napari_deeplabcut/_tracking_worker.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/napari_deeplabcut/_tracking_worker.py b/src/napari_deeplabcut/_tracking_worker.py index e426679..837c015 100644 --- a/src/napari_deeplabcut/_tracking_worker.py +++ b/src/napari_deeplabcut/_tracking_worker.py @@ -114,6 +114,7 @@ def _start(self): print("Started tracking") print(f"Is ready : {self._check_ready()}") return + ### Below is code to start the worker and update the button for the use to start/stop the tracking process if not self.check_ready(): err = "Aborting, please choose valid inputs" self.log.print_and_log(err) @@ -199,6 +200,12 @@ def warn(self, msg): """Log a warning.""" self.warn_signal.emit(msg) + def run_tracking(self): + """Run the tracking.""" + # TODO : Implement the tracking process + self.log("Started tracking") + return + def track_mock( video: Path, From 2d5441b211047899f52e8d3e7b8ad4670412dc62 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 26 Apr 2024 14:17:18 +0200 Subject: [PATCH 03/34] Add missing manifest --- src/napari_deeplabcut/_tracking_worker.py | 15 ++++++++++++--- src/napari_deeplabcut/napari.yaml | 5 +++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/src/napari_deeplabcut/_tracking_worker.py b/src/napari_deeplabcut/_tracking_worker.py index 837c015..72418f4 100644 --- a/src/napari_deeplabcut/_tracking_worker.py +++ b/src/napari_deeplabcut/_tracking_worker.py @@ -145,6 +145,15 @@ def _start(self): self.worker.start() self.btn_start.setText("Running... Click to stop") + def _setup_worker(self): + pass # TODO : Implement the worker setup + + def _on_yield(self, results): + pass + + +### -------- Tracking worker -------- ### + class LogSignal(WorkerBaseSignals): """Signal to send messages to be logged from another thread. @@ -171,9 +180,6 @@ def __init__(self, parent=None): super().__init__(parent=parent) -### -------- Tracking code -------- ### - - class TrackingWorker(GeneratorWorker): """A custom worker to run tracking in.""" @@ -204,6 +210,9 @@ def run_tracking(self): """Run the tracking.""" # TODO : Implement the tracking process self.log("Started tracking") + + # This must yield the tracking results for each frame to be displayed in the viewer + # yield ... ideally a class that contains data that can readily be used by napari return diff --git a/src/napari_deeplabcut/napari.yaml b/src/napari_deeplabcut/napari.yaml index f98c620..02f8907 100644 --- a/src/napari_deeplabcut/napari.yaml +++ b/src/napari_deeplabcut/napari.yaml @@ -26,6 +26,9 @@ contributions: - id: napari-deeplabcut.make_keypoint_controls python_name: napari_deeplabcut._widgets:KeypointControls title: Make keypoint controls + - id: napari-deeplabcut.tracking_demo + python_name: napari_deeplabcut._tracking_worker:TrackingModule + title: Tracking demo readers: - command: napari-deeplabcut.get_hdf_reader accepts_directories: false @@ -52,3 +55,5 @@ contributions: widgets: - command: napari-deeplabcut.make_keypoint_controls display_name: Keypoint controls + - command: napari-deeplabcut.tracking_demo + display_name: Tracking demo From a7b3c240e51cfe458cac37d22c1a6732bf5d3337 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 26 Apr 2024 14:57:47 +0200 Subject: [PATCH 04/34] Add basic worker signals --- src/napari_deeplabcut/_tracking_utils.py | 6 ++++ src/napari_deeplabcut/_tracking_worker.py | 44 +++++++++++++++++++++-- 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/src/napari_deeplabcut/_tracking_utils.py b/src/napari_deeplabcut/_tracking_utils.py index b9c628a..f10fcd9 100644 --- a/src/napari_deeplabcut/_tracking_utils.py +++ b/src/napari_deeplabcut/_tracking_utils.py @@ -1,4 +1,5 @@ ### ------------- Custom widgets for tracking module -------------- ### +import datetime import logging import threading from functools import partial @@ -53,6 +54,11 @@ def __init__( self.layout.setSizeConstraint(QLayout.SetFixedSize) +def get_time(): + """Get time in the following format : hour:minute:second. NOT COMPATIBLE with file paths (saving with ":" is invalid).""" + return f"{datetime.now():%H:%M:%S}" + + def add_widgets(layout, widgets): """Adds all widgets in the list to layout, with the specified alignment. diff --git a/src/napari_deeplabcut/_tracking_worker.py b/src/napari_deeplabcut/_tracking_worker.py index 72418f4..7bf0149 100644 --- a/src/napari_deeplabcut/_tracking_worker.py +++ b/src/napari_deeplabcut/_tracking_worker.py @@ -2,6 +2,7 @@ # - Implement the tracking worker with multithreading # - Implement a Log widget to display the tracking progress + a progress bar # - Prepare I/O with the actual tracking backend +from functools import partial from pathlib import Path import napari @@ -23,6 +24,7 @@ Log, QWidgetSingleton, add_widgets, + get_time, ) @@ -113,7 +115,7 @@ def _start(self): # TODO : implement the tracking process print("Started tracking") print(f"Is ready : {self._check_ready()}") - return + # TODO : setup worker ### Below is code to start the worker and update the button for the use to start/stop the tracking process if not self.check_ready(): err = "Aborting, please choose valid inputs" @@ -146,11 +148,49 @@ def _start(self): self.btn_start.setText("Running... Click to stop") def _setup_worker(self): - pass # TODO : Implement the worker setup + self.worker.started.connect(self.on_start) + + self.worker.log_signal.connect(self.log.print_and_log) + self.worker.log_w_replace_signal.connect(self.log.replace_last_line) + self.worker.warn_signal.connect(self.log.warn) + self.worker.error_signal.connect(self.log.error) + + self.worker.yielded.connect(partial(self.on_yield)) + self.worker.errored.connect(partial(self.on_error)) + self.worker.finished.connect(self.on_finish) def _on_yield(self, results): + # TODO : display the results in the viewer pass + def _on_start(self): + """Catches start signal from worker to call :py:func:`~display_status_report`.""" + self.display_status_report() + self._set_self_config() + self.log.print_and_log(f"Worker started at {get_time()}") + self.log.print_and_log(f"Saving results to : {self.results_path}") + self.log.print_and_log("Worker is running...") + + def _on_error(self, error): + """Catches errors and tries to clean up.""" + self.log.print_and_log("!" * 20) + self.log.print_and_log("Worker errored...") + self.log.error(error) + self.worker.quit() + self.on_finish() + + def _on_finish(self): + """Catches finished signal from worker, resets workspace for next run.""" + self.log.print_and_log(f"\nWorker finished at {get_time()}") + self.log.print_and_log("*" * 20) + self.btn_start.setText("Start") + self.btn_close.setVisible(True) + + self.worker = None + self.worker_config = None + self.empty_cuda_cache() + return True # signal clean exit + ### -------- Tracking worker -------- ### From 83a7bddf6e09048307211e71acd6fb5925c28500 Mon Sep 17 00:00:00 2001 From: Arash Sal Moslehian Date: Fri, 26 Apr 2024 13:30:12 +0000 Subject: [PATCH 05/34] tracking: fix migration errors --- src/napari_deeplabcut/_tracking_utils.py | 2 + src/napari_deeplabcut/_tracking_worker.py | 70 ++++++++++++----------- 2 files changed, 38 insertions(+), 34 deletions(-) diff --git a/src/napari_deeplabcut/_tracking_utils.py b/src/napari_deeplabcut/_tracking_utils.py index f10fcd9..a45e7bd 100644 --- a/src/napari_deeplabcut/_tracking_utils.py +++ b/src/napari_deeplabcut/_tracking_utils.py @@ -4,6 +4,8 @@ import threading from functools import partial from typing import Optional +from datetime import datetime + import napari from qtpy import QtCore diff --git a/src/napari_deeplabcut/_tracking_worker.py b/src/napari_deeplabcut/_tracking_worker.py index 7bf0149..ac145dc 100644 --- a/src/napari_deeplabcut/_tracking_worker.py +++ b/src/napari_deeplabcut/_tracking_worker.py @@ -102,7 +102,7 @@ def _build(self): add_widgets(layout, widgets) self.setLayout(layout) - def _check_ready(self): + def check_ready(self): """Check if the inputs are ready for tracking.""" if self.video_layer_dropdown.layer is None: return False @@ -114,7 +114,7 @@ def _start(self): """Start the tracking process.""" # TODO : implement the tracking process print("Started tracking") - print(f"Is ready : {self._check_ready()}") + print(f"Is ready : {self.check_ready()}") # TODO : setup worker ### Below is code to start the worker and update the button for the use to start/stop the tracking process if not self.check_ready(): @@ -122,42 +122,46 @@ def _start(self): self.log.print_and_log(err) raise ValueError(err) - if self.worker is not None: - if self.worker.is_running: + if self._worker is not None: + if self._worker.is_running: pass else: - self.worker.start() - self.btn_start.setText("Running... Click to stop") + self._worker.start() else: self.log.print_and_log("Starting...") self.log.print_and_log("*" * 20) - # self._set_worker_config() - # if self.worker_config is None: - # raise RuntimeError("Worker config was not set correctly") - # self._setup_worker() - self.btn_close.setVisible(False) + self._setup_worker() - if self.worker.is_running: # if worker is running, tries to stop + if self._worker.is_running: # if worker is running, tries to stop self.log.print_and_log( "Stop request, waiting for next inference..." ) - self.btn_start.setText("Stopping...") - self.worker.quit() + self._worker.quit() else: # once worker is started, update buttons - self.worker.start() - self.btn_start.setText("Running... Click to stop") + self._worker.start() def _setup_worker(self): - self.worker.started.connect(self.on_start) - self.worker.log_signal.connect(self.log.print_and_log) - self.worker.log_w_replace_signal.connect(self.log.replace_last_line) - self.worker.warn_signal.connect(self.log.warn) - self.worker.error_signal.connect(self.log.error) + self._worker = TrackingWorker() + self._worker.started.connect(self._on_start) - self.worker.yielded.connect(partial(self.on_yield)) - self.worker.errored.connect(partial(self.on_error)) - self.worker.finished.connect(self.on_finish) + self._worker.log_signal.connect(self.log.print_and_log) + self._worker.log_w_replace_signal.connect(self.log.replace_last_line) + self._worker.warn_signal.connect(self.log.warn) + self._worker.error_signal.connect(self.log.error) + + self._worker.yielded.connect(partial(self._on_yield)) + self._worker.errored.connect(partial(self._on_error)) + self._worker.finished.connect(self._on_finish) + + keypoint_cord = self.keypoint_layer_dropdown.layer_data() + frames = self.video_layer_dropdown.layer_data() + + self.log.print_and_log(f"keypoint started at {keypoint_cord}") + self.log.print_and_log(f"frames started at {frames}") + + + def _on_yield(self, results): # TODO : display the results in the viewer @@ -165,10 +169,10 @@ def _on_yield(self, results): def _on_start(self): """Catches start signal from worker to call :py:func:`~display_status_report`.""" - self.display_status_report() - self._set_self_config() + # self.display_status_report() + # self._set_self_config() self.log.print_and_log(f"Worker started at {get_time()}") - self.log.print_and_log(f"Saving results to : {self.results_path}") + #self.log.print_and_log(f"Saving results to : {self.results_path}") self.log.print_and_log("Worker is running...") def _on_error(self, error): @@ -176,19 +180,17 @@ def _on_error(self, error): self.log.print_and_log("!" * 20) self.log.print_and_log("Worker errored...") self.log.error(error) - self.worker.quit() + self._worker.quit() self.on_finish() def _on_finish(self): """Catches finished signal from worker, resets workspace for next run.""" self.log.print_and_log(f"\nWorker finished at {get_time()}") self.log.print_and_log("*" * 20) - self.btn_start.setText("Start") - self.btn_close.setVisible(True) - self.worker = None - self.worker_config = None - self.empty_cuda_cache() + self._worker = None + self._worker_config = None + # self.empty_cuda_cache() return True # signal clean exit @@ -253,7 +255,7 @@ def run_tracking(self): # This must yield the tracking results for each frame to be displayed in the viewer # yield ... ideally a class that contains data that can readily be used by napari - return + yield def track_mock( From 4a3b7a1e22df5e452d0f1c204bd05fdc2b97d924 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 26 Apr 2024 15:37:24 +0200 Subject: [PATCH 06/34] Add proper worker start and yield --- src/napari_deeplabcut/_tracking_utils.py | 2 +- src/napari_deeplabcut/_tracking_worker.py | 115 ++++++++++++++-------- 2 files changed, 76 insertions(+), 41 deletions(-) diff --git a/src/napari_deeplabcut/_tracking_utils.py b/src/napari_deeplabcut/_tracking_utils.py index f10fcd9..a0c5dc7 100644 --- a/src/napari_deeplabcut/_tracking_utils.py +++ b/src/napari_deeplabcut/_tracking_utils.py @@ -1,7 +1,7 @@ ### ------------- Custom widgets for tracking module -------------- ### -import datetime import logging import threading +from datetime import datetime from functools import partial from typing import Optional diff --git a/src/napari_deeplabcut/_tracking_worker.py b/src/napari_deeplabcut/_tracking_worker.py index 7bf0149..e822699 100644 --- a/src/napari_deeplabcut/_tracking_worker.py +++ b/src/napari_deeplabcut/_tracking_worker.py @@ -110,65 +110,96 @@ def _check_ready(self): return False return True + def _display_status_report(self): + """Adds a text log, a progress bar and a "save log" button on the left side of the viewer (usually when starting a worker).""" + + if self.container_docked: + self.log.clear() + elif not self.container_docked: + add_widgets( + self.report_container.layout, + [self.progress, self.log], + ) + self.report_container.setLayout(self.report_container.layout) + report_dock = self._viewer.window.add_dock_widget( + self.report_container, + name="Status report", + area="left", + allowed_areas=["left", "right"], + ) + report_dock._close_btn = False + + # self.docked_widgets.append(report_dock) + self.container_docked = True + + self.log.setVisible(True) + self.progress.setVisible(True) + self.progress.setValue(0) + + def _update_progress_bar(self, current_frame, total_frame): + """Update the progress bar.""" + pbar_value = (current_frame / total_frame) * 100 + if pbar_value > 100: + pbar_value = 100 + + self.progress.setValue(pbar_value) + def _start(self): """Start the tracking process.""" - # TODO : implement the tracking process print("Started tracking") - print(f"Is ready : {self._check_ready()}") - # TODO : setup worker ### Below is code to start the worker and update the button for the use to start/stop the tracking process - if not self.check_ready(): - err = "Aborting, please choose valid inputs" - self.log.print_and_log(err) - raise ValueError(err) + # if not self._check_ready(): + # err = "Aborting, please choose valid inputs" + # self.log.print_and_log(err) + # raise ValueError(err) - if self.worker is not None: - if self.worker.is_running: + if self._worker is not None: + if self._worker.is_running: pass else: - self.worker.start() - self.btn_start.setText("Running... Click to stop") + self._worker.start() + self.start_button.setText("Running... Click to stop") else: self.log.print_and_log("Starting...") self.log.print_and_log("*" * 20) - # self._set_worker_config() - # if self.worker_config is None: - # raise RuntimeError("Worker config was not set correctly") - # self._setup_worker() - self.btn_close.setVisible(False) + self._setup_worker() - if self.worker.is_running: # if worker is running, tries to stop + if self._worker.is_running: # if worker is running, tries to stop self.log.print_and_log( "Stop request, waiting for next inference..." ) - self.btn_start.setText("Stopping...") - self.worker.quit() + self.start_button.setText("Stopping...") + self._worker.quit() else: # once worker is started, update buttons - self.worker.start() - self.btn_start.setText("Running... Click to stop") + self._worker.start() + self.start_button.setText("Running... Click to stop") def _setup_worker(self): - self.worker.started.connect(self.on_start) + self._worker = TrackingWorker() + + self._worker.started.connect(self._on_start) - self.worker.log_signal.connect(self.log.print_and_log) - self.worker.log_w_replace_signal.connect(self.log.replace_last_line) - self.worker.warn_signal.connect(self.log.warn) - self.worker.error_signal.connect(self.log.error) + self._worker.log_signal.connect(self.log.print_and_log) + self._worker.log_w_replace_signal.connect(self.log.replace_last_line) + self._worker.warn_signal.connect(self.log.warn) + self._worker.error_signal.connect(self.log.error) - self.worker.yielded.connect(partial(self.on_yield)) - self.worker.errored.connect(partial(self.on_error)) - self.worker.finished.connect(self.on_finish) + self._worker.yielded.connect(partial(self._on_yield)) + self._worker.errored.connect(partial(self._on_error)) + self._worker.finished.connect(self._on_finish) def _on_yield(self, results): # TODO : display the results in the viewer - pass + # Testing version where an int i is yielded + ############################ + self.log.print_and_log(f"Yielded {results}") + self._update_progress_bar(results, 10) + ############################ def _on_start(self): """Catches start signal from worker to call :py:func:`~display_status_report`.""" - self.display_status_report() - self._set_self_config() + self._display_status_report() self.log.print_and_log(f"Worker started at {get_time()}") - self.log.print_and_log(f"Saving results to : {self.results_path}") self.log.print_and_log("Worker is running...") def _on_error(self, error): @@ -176,19 +207,16 @@ def _on_error(self, error): self.log.print_and_log("!" * 20) self.log.print_and_log("Worker errored...") self.log.error(error) - self.worker.quit() + self._worker.quit() self.on_finish() def _on_finish(self): """Catches finished signal from worker, resets workspace for next run.""" self.log.print_and_log(f"\nWorker finished at {get_time()}") self.log.print_and_log("*" * 20) - self.btn_start.setText("Start") - self.btn_close.setVisible(True) + self.start_button.setText("Start") - self.worker = None - self.worker_config = None - self.empty_cuda_cache() + self._worker = None return True # signal clean exit @@ -225,7 +253,8 @@ class TrackingWorker(GeneratorWorker): def __init__(self, config=None): """Creates a TrackingWorker.""" - super().__init__(self.run_tracking) + # super().__init__(self.run_tracking) #### TODO MUST BE CHANGED WHEN REAL TRACKING IS IMPLEMENTED + super().__init__(self.fake_tracking) self._signals = LogSignal() self.log_signal = self._signals.log_signal self.log_w_replace_signal = self._signals.log_w_replace_signal @@ -255,6 +284,12 @@ def run_tracking(self): # yield ... ideally a class that contains data that can readily be used by napari return + def fake_tracking(self): + """Fake tracking for testing purposes.""" + for i in range(10): + self.log(f"Tracking frame {i}") + yield i + 1 + def track_mock( video: Path, From a06ce570b6858ee1dfb0d5a99519cd9e3339ec8e Mon Sep 17 00:00:00 2001 From: Arash Sal Moslehian Date: Fri, 26 Apr 2024 13:30:12 +0000 Subject: [PATCH 07/34] tracking: fix migration errors --- src/napari_deeplabcut/_tracking_worker.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/napari_deeplabcut/_tracking_worker.py b/src/napari_deeplabcut/_tracking_worker.py index e822699..1007356 100644 --- a/src/napari_deeplabcut/_tracking_worker.py +++ b/src/napari_deeplabcut/_tracking_worker.py @@ -102,7 +102,7 @@ def _build(self): add_widgets(layout, widgets) self.setLayout(layout) - def _check_ready(self): + def check_ready(self): """Check if the inputs are ready for tracking.""" if self.video_layer_dropdown.layer is None: return False @@ -188,6 +188,12 @@ def _setup_worker(self): self._worker.errored.connect(partial(self._on_error)) self._worker.finished.connect(self._on_finish) + keypoint_cord = self.keypoint_layer_dropdown.layer_data() + frames = self.video_layer_dropdown.layer_data() + + self.log.print_and_log(f"keypoint started at {keypoint_cord}") + self.log.print_and_log(f"frames started at {frames}") + def _on_yield(self, results): # TODO : display the results in the viewer # Testing version where an int i is yielded @@ -282,7 +288,7 @@ def run_tracking(self): # This must yield the tracking results for each frame to be displayed in the viewer # yield ... ideally a class that contains data that can readily be used by napari - return + yield def fake_tracking(self): """Fake tracking for testing purposes.""" From 2af5f433df2f8682f7a6469e403863381ee2e764 Mon Sep 17 00:00:00 2001 From: Niels Poulsen Date: Fri, 26 Apr 2024 15:59:35 +0200 Subject: [PATCH 08/34] implemented run_tracking --- src/napari_deeplabcut/_tracking_worker.py | 22 ++++++++++------------ 1 file changed, 10 insertions(+), 12 deletions(-) diff --git a/src/napari_deeplabcut/_tracking_worker.py b/src/napari_deeplabcut/_tracking_worker.py index ac145dc..c9fef24 100644 --- a/src/napari_deeplabcut/_tracking_worker.py +++ b/src/napari_deeplabcut/_tracking_worker.py @@ -248,18 +248,20 @@ def warn(self, msg): """Log a warning.""" self.warn_signal.emit(msg) - def run_tracking(self): + def run_tracking( + self, + video: np.ndarray, + keypoints: np.ndarray, + ): """Run the tracking.""" - # TODO : Implement the tracking process self.log("Started tracking") - - # This must yield the tracking results for each frame to be displayed in the viewer - # yield ... ideally a class that contains data that can readily be used by napari - yield + tracks = track_mock(video, keypoints) + self.log("Finished tracking") + yield tracks def track_mock( - video: Path, + video: np.ndarray, keypoints: np.ndarray, ) -> np.ndarray: """Mocks what a tracker would do. @@ -278,11 +280,7 @@ def track_mock( an array of shape (num_frames, n_animals, n_keypoints, 2) corresponding to the position of each keypoint in each frame of the video """ - - def get_num_frames(video: Path) -> int: - return 0 - - return np.repeat(keypoints, (get_num_frames(video), 1, 1, 1)) + return np.repeat(keypoints, (len(video), 1, 1, 1)) def track_cotracker( From 27498e52980be93457f98f77385164ed8336bc08 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 26 Apr 2024 18:36:02 +0200 Subject: [PATCH 09/34] Try to display colors --- src/napari_deeplabcut/_tracking_utils.py | 4 +- src/napari_deeplabcut/_tracking_worker.py | 37 ++++- src/napari_deeplabcut/_widgets.py | 193 +++++++++++++++------- 3 files changed, 161 insertions(+), 73 deletions(-) diff --git a/src/napari_deeplabcut/_tracking_utils.py b/src/napari_deeplabcut/_tracking_utils.py index 5a7ab2b..56acb93 100644 --- a/src/napari_deeplabcut/_tracking_utils.py +++ b/src/napari_deeplabcut/_tracking_utils.py @@ -4,8 +4,6 @@ from datetime import datetime from functools import partial from typing import Optional -from datetime import datetime - import napari from qtpy import QtCore @@ -306,7 +304,7 @@ def layer_data(self): return None try: if self.layer_type == napari.layers.Points: - return self.layer().features + return self.layer().data else: return self.layer().data except (KeyError, ValueError): diff --git a/src/napari_deeplabcut/_tracking_worker.py b/src/napari_deeplabcut/_tracking_worker.py index 30d3088..c060a37 100644 --- a/src/napari_deeplabcut/_tracking_worker.py +++ b/src/napari_deeplabcut/_tracking_worker.py @@ -26,14 +26,15 @@ add_widgets, get_time, ) +from napari_deeplabcut.keypoints import KeypointStore class TrackingModule(QWidget, metaclass=QWidgetSingleton): """Plugin for tracking.""" - def __init__(self, napari_viewer: "napari.viewer.Viewer"): + def __init__(self, napari_viewer: "napari.viewer.Viewer", parent=None): """Creates a widget with links to documentation and about page.""" - super().__init__() + super().__init__(parent=parent) self._viewer = napari_viewer self._worker = None self._keypoint_layer = None @@ -102,7 +103,7 @@ def _build(self): add_widgets(layout, widgets) self.setLayout(layout) - def check_ready(self): + def _check_ready(self): """Check if the inputs are ready for tracking.""" if self.video_layer_dropdown.layer is None: return False @@ -149,10 +150,10 @@ def _start(self): print("Started tracking") ### Below is code to start the worker and update the button for the use to start/stop the tracking process - # if not self._check_ready(): - # err = "Aborting, please choose valid inputs" - # self.log.print_and_log(err) - # raise ValueError(err) + if not self._check_ready(): + err = "Aborting, please choose valid inputs" + self.log.print_and_log(err) + raise ValueError(err) if self._worker is not None: if self._worker.is_running: @@ -196,10 +197,29 @@ def _setup_worker(self): self.log.print_and_log(f"keypoint started at {keypoint_cord}") self.log.print_and_log(f"frames started at {frames}") + def _display_results(self, results): + """Display the results in the viewer, using the method already implemented in the viewer.""" + path_test = "C:/Users/Cyril/Desktop/Code/DeepLabCut/examples/openfield-Pranav-2018-10-30/labeled-data/m4s1/CollectedData_Pranav.h5" + from napari_deeplabcut._reader import read_config, read_hdf + + keypoint_data, metadata, _ = read_hdf(path_test)[0] + # hdf data contains : keypoint data, metadata, and "points" + # we want to create a points layer from the keypoint data + # layer properties (dict) should be populated with metadata + print(metadata) + layer = self._viewer.add_points( + keypoint_data, + name="keypoints_hdf_test", + metadata=metadata["metadata"], + features=metadata["properties"], + properties=metadata["properties"], + ) + self.parent().parent().update_layer(layer) def _on_yield(self, results): # TODO : display the results in the viewer # Testing version where an int i is yielded + self._display_results(results) ############################ self.log.print_and_log(f"Yielded {results}") self._update_progress_bar(results, 10) @@ -295,12 +315,11 @@ def run_tracking( def fake_tracking(self): """Fake tracking for testing purposes.""" - for i in range(10): + for i in range(1): self.log(f"Tracking frame {i}") yield i + 1 - def track_mock( video: np.ndarray, keypoints: np.ndarray, diff --git a/src/napari_deeplabcut/_widgets.py b/src/napari_deeplabcut/_widgets.py index d1512cc..cc93d5d 100644 --- a/src/napari_deeplabcut/_widgets.py +++ b/src/napari_deeplabcut/_widgets.py @@ -3,19 +3,21 @@ from collections import defaultdict, namedtuple from copy import deepcopy from datetime import datetime -from functools import partial, cached_property +from functools import cached_property, partial from math import ceil, log10 -import matplotlib.pyplot as plt -import matplotlib.style as mplstyle -import napari -import pandas as pd from pathlib import Path from types import MethodType from typing import Optional, Sequence, Union -from matplotlib.backends.backend_qtagg import FigureCanvas, NavigationToolbar2QT - +import matplotlib.pyplot as plt +import matplotlib.style as mplstyle +import napari import numpy as np +import pandas as pd +from matplotlib.backends.backend_qtagg import ( + FigureCanvas, + NavigationToolbar2QT, +) from napari._qt.widgets.qt_welcome import QtWelcomeLabel from napari.layers import Image, Points, Shapes, Tracks from napari.layers.points._points_key_bindings import register_points_action @@ -23,8 +25,8 @@ from napari.layers.utils.layer_utils import _features_to_properties from napari.utils.events import Event from napari.utils.history import get_save_history, update_save_history -from qtpy.QtCore import Qt, QTimer, Signal, QPoint, QSettings, QSize -from qtpy.QtGui import QPainter, QAction, QCursor, QIcon +from qtpy.QtCore import QPoint, QSettings, QSize, Qt, QTimer, Signal +from qtpy.QtGui import QAction, QCursor, QIcon, QPainter from qtpy.QtSvgWidgets import QSvgWidget from qtpy.QtWidgets import ( QButtonGroup, @@ -50,12 +52,13 @@ from napari_deeplabcut import keypoints from napari_deeplabcut._reader import _load_config -from napari_deeplabcut._writer import _write_config, _write_image, _form_df +from napari_deeplabcut._tracking_worker import TrackingModule +from napari_deeplabcut._writer import _form_df, _write_config, _write_image from napari_deeplabcut.misc import ( + build_color_cycles, encode_categories, - to_os_dir_sep, guarantee_multiindex_rows, - build_color_cycles, + to_os_dir_sep, ) Tip = namedtuple("Tip", ["msg", "pos"]) @@ -69,7 +72,9 @@ def __init__(self, parent): self.setParent(parent) self.setWindowTitle("Shortcuts") - image_path = str(Path(__file__).parent / "assets" / "napari_shortcuts.svg") + image_path = str( + Path(__file__).parent / "assets" / "napari_shortcuts.svg" + ) vlayout = QVBoxLayout() svg_widget = QSvgWidget(image_path) @@ -118,7 +123,9 @@ def __init__(self, parent): ] vlayout = QVBoxLayout() - self.message = QLabel("💡\n\nLet's get started with a quick walkthrough!") + self.message = QLabel( + "💡\n\nLet's get started with a quick walkthrough!" + ) self.message.setTextInteractionFlags(Qt.LinksAccessibleByMouse) self.message.setOpenExternalLinks(True) vlayout.addWidget(self.message) @@ -194,10 +201,7 @@ def _get_and_try_preferred_reader( # where property is understood as continuous if # there are more than 16 unique categories... def guess_continuous(property): - if issubclass(property.dtype.type, np.floating): - return True - else: - return False + return bool(issubclass(property.dtype.type, np.floating)) color_manager.guess_continuous = guess_continuous @@ -236,12 +240,17 @@ def _paste_data(self, store): not_disp = self._slice_input.not_displayed data = deepcopy(self._clipboard["data"]) offset = [ - self._slice_indices[i] - self._clipboard["indices"][i] for i in not_disp + self._slice_indices[i] - self._clipboard["indices"][i] + for i in not_disp ] data[:, not_disp] = data[:, not_disp] + np.array(offset) self._data = np.append(self.data, data, axis=0) - self._shown = np.append(self.shown, deepcopy(self._clipboard["shown"]), axis=0) - self._size = np.append(self.size, deepcopy(self._clipboard["size"]), axis=0) + self._shown = np.append( + self.shown, deepcopy(self._clipboard["shown"]), axis=0 + ) + self._size = np.append( + self.size, deepcopy(self._clipboard["size"]), axis=0 + ) self._symbol = np.append( self.symbol, deepcopy(self._clipboard["symbol"]), axis=0 ) @@ -405,6 +414,8 @@ def __init__(self, napari_viewer, parent=None): layout.addLayout(layout2) self.setLayout(layout) + ############################ + self.frames = [] self.keypoints = [] self.df = None @@ -497,7 +508,9 @@ def _load_dataframe(self): if points_layer is None or ~np.any(points_layer.data): return - self.viewer.window.add_dock_widget(self, name="Trajectory plot", area="right") + self.viewer.window.add_dock_widget( + self, name="Trajectory plot", area="right" + ) self.hide() self.df = _form_df( @@ -508,9 +521,13 @@ def _load_dataframe(self): }, ) for keypoint in self.df.columns.get_level_values("bodyparts").unique(): - y = self.df.xs((keypoint, "y"), axis=1, level=["bodyparts", "coords"]) + y = self.df.xs( + (keypoint, "y"), axis=1, level=["bodyparts", "coords"] + ) x = np.arange(len(y)) - color = points_layer.metadata["face_color_cycles"]["label"][keypoint] + color = points_layer.metadata["face_color_cycles"]["label"][ + keypoint + ] lines = self.ax.plot(x, y, color=color, label=keypoint) self._lines[keypoint] = lines @@ -553,9 +570,11 @@ def __init__(self, napari_viewer): self.viewer.layers.events.inserted.connect(self.on_insert) self.viewer.layers.events.removed.connect(self.on_remove) - self.viewer.window.qt_viewer._get_and_try_preferred_reader = MethodType( - _get_and_try_preferred_reader, - self.viewer.window.qt_viewer, + self.viewer.window.qt_viewer._get_and_try_preferred_reader = ( + MethodType( + _get_and_try_preferred_reader, + self.viewer.window.qt_viewer, + ) ) status_bar = self.viewer.window._qt_window.statusBar() @@ -593,7 +612,9 @@ def __init__(self, napari_viewer): self._layout = QVBoxLayout(self) self._menus = [] self._layer_to_menu = {} - self.viewer.layers.selection.events.active.connect(self.on_active_layer_change) + self.viewer.layers.selection.events.active.connect( + self.on_active_layer_change + ) self._video_group = self._form_video_action_menu() self.video_widget = self.viewer.window.add_dock_widget( @@ -635,9 +656,14 @@ def __init__(self, napari_viewer): self._radio_group = self._form_mode_radio_buttons() # form color scheme display + color mode selector - self._color_grp, self._color_mode_selector = self._form_color_mode_selector() + ( + self._color_grp, + self._color_mode_selector, + ) = self._form_color_mode_selector() self._display = ColorSchemeDisplay(parent=self) - self._color_scheme_display = self._form_color_scheme_display(self.viewer) + self._color_scheme_display = self._form_color_scheme_display( + self.viewer + ) self._view_scheme_cb.toggled.connect(self._show_color_scheme) self._view_scheme_cb.toggle() self._display.added.connect( @@ -646,6 +672,10 @@ def __init__(self, napari_viewer): ), ) + # create a QtGroup for the tracking module + self._tracking_group = self._form_tracking_module() + self._layout.addWidget(self._tracking_group) + # Substitute default menu action with custom one for action in self.viewer.window.file_menu.actions()[::-1]: action_name = action.text().lower() @@ -674,8 +704,12 @@ def __init__(self, napari_viewer): self.viewer.window._qt_viewer.viewerButtons.gridViewButton.hide() self.viewer.window._qt_viewer.viewerButtons.rollDimsButton.hide() self.viewer.window._qt_viewer.viewerButtons.transposeDimsButton.hide() - self.viewer.window._qt_viewer.layerButtons.newPointsButton.setDisabled(True) - self.viewer.window._qt_viewer.layerButtons.newLabelsButton.setDisabled(True) + self.viewer.window._qt_viewer.layerButtons.newPointsButton.setDisabled( + True + ) + self.viewer.window._qt_viewer.layerButtons.newLabelsButton.setDisabled( + True + ) if self.settings.value("first_launch", True) and not os.environ.get( "hide_tutorial", False @@ -755,6 +789,14 @@ def _form_help_buttons(self): layout.addWidget(tutorial) return layout + def _form_tracking_module(self): + group_box = QGroupBox("Tracking") + tracking = TrackingModule(napari_viewer=self.viewer, parent=self) + layout = QVBoxLayout() + layout.addWidget(tracking) + group_box.setLayout(layout) + return group_box + def _extract_single_frame(self, *args): image_layer = None points_layer = None @@ -780,8 +822,10 @@ def _extract_single_frame(self, *args): "properties": points_layer.properties, }, ) - df = df.iloc[ind: ind + 1] - df.index = pd.MultiIndex.from_tuples([Path(output_path).parts[-3:]]) + df = df.iloc[ind : ind + 1] + df.index = pd.MultiIndex.from_tuples( + [Path(output_path).parts[-3:]] + ) filepath = os.path.join( image_layer.metadata["root"], "machinelabels-iter0.h5" ) @@ -811,7 +855,9 @@ def _store_crop_coordinates(self, *args): config_path = os.path.join(project_path, "config.yaml") cfg = _load_config(config_path) cfg["video_sets"][ - os.path.join(project_path, "videos", self._images_meta["name"]) + os.path.join( + project_path, "videos", self._images_meta["name"] + ) ] = temp _write_config(config_path, cfg) break @@ -875,7 +921,10 @@ def _form_color_scheme_display(self, viewer): def _update_color_scheme(self): def to_hex(nparray): a = np.array(nparray * 255, dtype=int) - rgb2hex = lambda r, g, b, _: f"#{r:02x}{g:02x}{b:02x}" + + def rgb2hex(r, g, b, _): + return f"#{r:02x}{g:02x}{b:02x}" + res = rgb2hex(*a) return res @@ -889,7 +938,9 @@ def to_hex(nparray): self._display.update_color_scheme( { name: to_hex(color) - for name, color in layer.metadata["face_color_cycles"][mode].items() + for name, color in layer.metadata["face_color_cycles"][ + mode + ].items() } ) @@ -902,7 +953,9 @@ def _remap_frame_indices(self, layer): if paths is not None and np.any(layer.data): paths_map = dict(zip(range(len(paths)), map(to_os_dir_sep, paths))) # Discard data if there are missing frames - missing = [i for i, path in paths_map.items() if path not in new_paths] + missing = [ + i for i, path in paths_map.items() if path not in new_paths + ] if missing: if isinstance(layer.data, list): inds_to_remove = [ @@ -911,7 +964,9 @@ def _remap_frame_indices(self, layer): if verts[0, 0] in missing ] else: - inds_to_remove = np.flatnonzero(np.isin(layer.data[:, 0], missing)) + inds_to_remove = np.flatnonzero( + np.isin(layer.data[:, 0], missing) + ) layer.selected_data = inds_to_remove layer.remove_selected() for i in missing: @@ -930,6 +985,9 @@ def _remap_frame_indices(self, layer): def on_insert(self, event): layer = event.source[-1] + self.update_layer(layer, event) + + def update_layer(self, layer, event=None): logging.debug(f"Inserting Layer {layer}") if isinstance(layer, Image): paths = layer.metadata.get("paths") @@ -945,9 +1003,9 @@ def on_insert(self, event): } ) # Delay layer sorting - QTimer.singleShot( - 10, partial(self._move_image_layer_to_bottom, event.index) - ) + # QTimer.singleShot( + # 10, partial(self._move_image_layer_to_bottom, event.index) + # ) elif isinstance(layer, Points): # If the current Points layer comes from a config file, some have already # been added and the body part names are different from the existing ones, @@ -957,13 +1015,17 @@ def on_insert(self, event): keypoints_menu = self._menus[0].menus["label"] current_keypoint_set = set( - keypoints_menu.itemText(i) for i in range(keypoints_menu.count()) + keypoints_menu.itemText(i) + for i in range(keypoints_menu.count()) ) new_keypoint_set = set(new_metadata["header"].bodyparts) diff = new_keypoint_set.difference(current_keypoint_set) + print(diff) if diff: answer = QMessageBox.question( - self, "", "Do you want to display the new keypoints only?" + self, + "", + "Do you want to display the new keypoints only?", ) if answer == QMessageBox.Yes: self.viewer.layers[-2].shown = False @@ -986,9 +1048,9 @@ def on_insert(self, event): "face_color_cycles" ] _layer.face_color = "label" - _layer.face_color_cycle = new_metadata["face_color_cycles"][ - "label" - ] + _layer.face_color_cycle = new_metadata[ + "face_color_cycles" + ]["label"] _layer.events.face_color() store.layer = _layer self._update_color_scheme() @@ -1073,9 +1135,9 @@ def on_remove(self, event): def on_active_layer_change(self, event) -> None: """Updates the GUI when the active layer changes - * Hides all KeypointsDropdownMenu that aren't for the selected layer - * Sets the visibility of the "Color mode" box to True if the selected layer - is a multi-animal one, or False otherwise + * Hides all KeypointsDropdownMenu that aren't for the selected layer + * Sets the visibility of the "Color mode" box to True if the selected layer + is a multi-animal one, or False otherwise """ self._color_grp.setVisible(self._is_multianimal(event.value)) menu_idx = -1 @@ -1092,7 +1154,8 @@ def _update_colormap(self, colormap_name): for layer in self.viewer.layers.selection: if isinstance(layer, Points) and layer.metadata: face_color_cycle_maps = build_color_cycles( - layer.metadata["header"], colormap_name, + layer.metadata["header"], + colormap_name, ) layer.metadata["face_color_cycles"] = face_color_cycle_maps face_color_prop = "label" @@ -1110,9 +1173,8 @@ def cycle_through_label_modes(self, *args): @register_points_action("Change color mode") def cycle_through_color_modes(self, *args): - if ( - self._active_layer_is_multianimal() - or self.color_mode != str(keypoints.ColorMode.BODYPART) + if self._active_layer_is_multianimal() or self.color_mode != str( + keypoints.ColorMode.BODYPART ): self.color_mode = next(keypoints.ColorMode) @@ -1152,7 +1214,8 @@ def color_mode(self, mode: Union[str, keypoints.ColorMode]): if isinstance(layer, Points) and layer.metadata: layer.face_color = face_color_mode layer.face_color_cycle = layer.metadata["face_color_cycles"][ - face_color_mode] + face_color_mode + ] layer.events.face_color() for btn in self._color_mode_selector.buttons(): @@ -1191,7 +1254,9 @@ def toggle_edge_color(layer): class DropdownMenu(QComboBox): - def __init__(self, labels: Sequence[str], parent: Optional[QWidget] = None): + def __init__( + self, labels: Sequence[str], parent: Optional[QWidget] = None + ): super().__init__(parent) self.update_items(labels) @@ -1461,7 +1526,9 @@ def _format_label(label: QLabel, height: int = None, width: int = None): def _build(self): layout = QHBoxLayout() - layout.addWidget(self.color_label, alignment=Qt.AlignmentFlag.AlignLeft) + layout.addWidget( + self.color_label, alignment=Qt.AlignmentFlag.AlignLeft + ) layout.addWidget(self.part_label, alignment=Qt.AlignmentFlag.AlignLeft) self.setLayout(layout) @@ -1527,7 +1594,9 @@ def _build(self): self.setBaseSize(100, 200) self.setVerticalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOn) - self.setHorizontalScrollBarPolicy(Qt.ScrollBarPolicy.ScrollBarAlwaysOff) + self.setHorizontalScrollBarPolicy( + Qt.ScrollBarPolicy.ScrollBarAlwaysOff + ) def add_entry(self, name, color): self.scheme_dict.update({name: color}) @@ -1538,7 +1607,9 @@ def add_entry(self, name, color): def update_color_scheme(self, new_color_scheme) -> None: logging.debug(f"Updating color scheme: {self._layout.count()} widgets") - self.scheme_dict = {name: color for name, color in new_color_scheme.items()} + self.scheme_dict = { + name: color for name, color in new_color_scheme.items() + } names = list(new_color_scheme.keys()) existing_widgets = self._layout.count() required_widgets = len(self.scheme_dict) @@ -1555,7 +1626,7 @@ def update_color_scheme(self, new_color_scheme) -> None: for i in range(max(existing_widgets - required_widgets, 0)): logging.debug(f" hiding {required_widgets + i}") if w := self._layout.itemAt(required_widgets + i).widget(): - logging.debug(f" done!") + logging.debug(" done!") w.setVisible(False) # add missing widgets @@ -1563,7 +1634,7 @@ def update_color_scheme(self, new_color_scheme) -> None: logging.debug(f" adding {existing_widgets + i}") name = names[existing_widgets + i] self.add_entry(name, self.scheme_dict[name]) - logging.debug(f" done!") + logging.debug(" done!") def reset(self): self.scheme_dict = {} From 45c7472915245e792d1157020cc70ae4eb234c2e Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 26 Apr 2024 18:38:12 +0200 Subject: [PATCH 10/34] Update _tracking_worker.py --- src/napari_deeplabcut/_tracking_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/napari_deeplabcut/_tracking_worker.py b/src/napari_deeplabcut/_tracking_worker.py index c060a37..20179e6 100644 --- a/src/napari_deeplabcut/_tracking_worker.py +++ b/src/napari_deeplabcut/_tracking_worker.py @@ -200,7 +200,7 @@ def _setup_worker(self): def _display_results(self, results): """Display the results in the viewer, using the method already implemented in the viewer.""" path_test = "C:/Users/Cyril/Desktop/Code/DeepLabCut/examples/openfield-Pranav-2018-10-30/labeled-data/m4s1/CollectedData_Pranav.h5" - from napari_deeplabcut._reader import read_config, read_hdf + from napari_deeplabcut._reader import read_hdf keypoint_data, metadata, _ = read_hdf(path_test)[0] # hdf data contains : keypoint data, metadata, and "points" From e11ff4cbf24d5bf9eba9e3ae2b4063d3adc42aa2 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Fri, 26 Apr 2024 18:59:56 +0200 Subject: [PATCH 11/34] Fix color display --- src/napari_deeplabcut/_tracking_worker.py | 18 ++++++++++++------ src/napari_deeplabcut/_widgets.py | 4 +++- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/src/napari_deeplabcut/_tracking_worker.py b/src/napari_deeplabcut/_tracking_worker.py index 20179e6..024fffb 100644 --- a/src/napari_deeplabcut/_tracking_worker.py +++ b/src/napari_deeplabcut/_tracking_worker.py @@ -26,8 +26,7 @@ add_widgets, get_time, ) -from napari_deeplabcut.keypoints import KeypointStore - +from napari.layers.utils.layer_utils import _features_to_properties class TrackingModule(QWidget, metaclass=QWidgetSingleton): """Plugin for tracking.""" @@ -194,8 +193,6 @@ def _setup_worker(self): keypoint_cord = self.keypoint_layer_dropdown.layer_data() frames = self.video_layer_dropdown.layer_data() - self.log.print_and_log(f"keypoint started at {keypoint_cord}") - self.log.print_and_log(f"frames started at {frames}") def _display_results(self, results): """Display the results in the viewer, using the method already implemented in the viewer.""" @@ -208,13 +205,22 @@ def _display_results(self, results): # layer properties (dict) should be populated with metadata print(metadata) layer = self._viewer.add_points( + ### data ### keypoint_data, name="keypoints_hdf_test", metadata=metadata["metadata"], - features=metadata["properties"], + # features=metadata["properties"], properties=metadata["properties"], + ### display properties ### + face_color=metadata["face_color"], + face_color_cycle=metadata["face_color_cycle"], + face_colormap=metadata["face_colormap"], + edge_color=metadata["edge_color"], + edge_color_cycle=metadata["edge_color_cycle"], + edge_width=metadata["edge_width"], + edge_width_is_relative=metadata["edge_width_is_relative"], + size=metadata["size"], ) - self.parent().parent().update_layer(layer) def _on_yield(self, results): # TODO : display the results in the viewer diff --git a/src/napari_deeplabcut/_widgets.py b/src/napari_deeplabcut/_widgets.py index cc93d5d..2fcf430 100644 --- a/src/napari_deeplabcut/_widgets.py +++ b/src/napari_deeplabcut/_widgets.py @@ -63,6 +63,8 @@ Tip = namedtuple("Tip", ["msg", "pos"]) +# enable debug logging +logging.basicConfig(level=logging.DEBUG) class Shortcuts(QDialog): """Opens a window displaying available napari-deeplabcut shortcuts""" @@ -985,7 +987,7 @@ def _remap_frame_indices(self, layer): def on_insert(self, event): layer = event.source[-1] - self.update_layer(layer, event) + self.update_layer(layer, event) # this was changed while trying to update the colors of the results layer, might be changed back def update_layer(self, layer, event=None): logging.debug(f"Inserting Layer {layer}") From 1dc9f8cc70d5e28bab08c57ad391634873f3aeba Mon Sep 17 00:00:00 2001 From: Niels Poulsen Date: Fri, 26 Apr 2024 19:30:50 +0200 Subject: [PATCH 12/34] attempt 1 --- src/napari_deeplabcut/_tracking_worker.py | 137 ++++++++++++++++++++-- 1 file changed, 125 insertions(+), 12 deletions(-) diff --git a/src/napari_deeplabcut/_tracking_worker.py b/src/napari_deeplabcut/_tracking_worker.py index 024fffb..1a9a456 100644 --- a/src/napari_deeplabcut/_tracking_worker.py +++ b/src/napari_deeplabcut/_tracking_worker.py @@ -7,6 +7,9 @@ import napari import numpy as np +import pandas as pd +import torch +from cotracker.predictor import CoTrackerOnlinePredictor from napari._qt.qthreading import GeneratorWorker from qtpy.QtCore import Signal from qtpy.QtWidgets import ( @@ -177,7 +180,21 @@ def _start(self): self.start_button.setText("Running... Click to stop") def _setup_worker(self): - self._worker = TrackingWorker() + metadata = self.keypoint_layer_dropdown.layer().metadata + properties = self.keypoint_layer_dropdown.layer().properties + keypoint_cord = self.keypoint_layer_dropdown.layer_data() + frames = self.video_layer_dropdown.layer_data() + + self._worker = TrackingWorker( + # metadata["metadata"]["root"], + # metadata["metadata"]["images"], + metadata["root"], + metadata["paths"], + properties["label"], + properties["id"], + frames, + keypoint_cord, + ) self._worker.started.connect(self._on_start) @@ -190,15 +207,12 @@ def _setup_worker(self): self._worker.errored.connect(partial(self._on_error)) self._worker.finished.connect(self._on_finish) - keypoint_cord = self.keypoint_layer_dropdown.layer_data() - frames = self.video_layer_dropdown.layer_data() - - def _display_results(self, results): """Display the results in the viewer, using the method already implemented in the viewer.""" - path_test = "C:/Users/Cyril/Desktop/Code/DeepLabCut/examples/openfield-Pranav-2018-10-30/labeled-data/m4s1/CollectedData_Pranav.h5" + # path_test = "C:/Users/Cyril/Desktop/Code/DeepLabCut/examples/openfield-Pranav-2018-10-30/labeled-data/m4s1/CollectedData_Pranav.h5" from napari_deeplabcut._reader import read_hdf + path_test = results keypoint_data, metadata, _ = read_hdf(path_test)[0] # hdf data contains : keypoint data, metadata, and "points" # we want to create a points layer from the keypoint data @@ -287,17 +301,23 @@ def __init__(self, parent=None): class TrackingWorker(GeneratorWorker): """A custom worker to run tracking in.""" - def __init__(self, config=None): + def __init__(self, root, image_paths, bodyparts, individuals, video, keypoints): """Creates a TrackingWorker.""" - # super().__init__(self.run_tracking) #### TODO MUST BE CHANGED WHEN REAL TRACKING IS IMPLEMENTED - super().__init__(self.fake_tracking) + super().__init__(self.run_tracking) # TODO MUST BE CHANGED WHEN REAL TRACKING IS IMPLEMENTED + # super().__init__(self.fake_tracking) + self._root = root + self._image_paths = image_paths + self._bodyparts = bodyparts + self._individuals = individuals + self._video = video + self._keypoints = keypoints self._signals = LogSignal() self.log_signal = self._signals.log_signal self.log_w_replace_signal = self._signals.log_w_replace_signal self.warn_signal = self._signals.warn_signal self.error_signal = self._signals.error_signal - self.config = config # use if needed + self.config = None # config # use if needed def log(self, msg): """Log a message.""" @@ -313,11 +333,43 @@ def warn(self, msg): def run_tracking( self, - video: np.ndarray, - keypoints: np.ndarray, + # video: np.ndarray, + # keypoints: np.ndarray, ): """Run the tracking.""" self.log("Started tracking") + self.log(self._video) + self.log(self._keypoints) + tracks = cotrack_online(self, self._video, self._keypoints) + self.log("Finished tracking") + track_path = Path(self._root) / "TrackedData.h5" + self.save_tracking_data(track_path, tracks, "CoTracker") + self.log("Finished saving") + yield track_path + + def save_tracking_data(self, path: Path, tracks: np.ndarray, scorer: str) -> None: + levels = ["scorer", "individuals", "bodyparts", "coords"] + kpt_entries = ["x", "y"] + columns = [] + for i in self._individuals: + for b in self._bodyparts: + columns += [(scorer, i, b, entry) for entry in kpt_entries] + + index = [] + for img_path in self._image_paths: + if isinstance(img_path, str): + index.append(tuple(Path(img_path).parts)) + elif isinstance(img_path, tuple): + index.append(img_path) + else: + raise ValueError(f"Incorrect image path format: {img_path}") + + dataframe = pd.DataFrame( + data=tracks.reshape((len(tracks), -1)), + index=pd.MultiIndex.from_tuples(index), + columns=pd.MultiIndex.from_tuples(columns, names=levels), + ) + dataframe.to_hdf(path, key="df_with_missing") def fake_tracking(self): """Fake tracking for testing purposes.""" @@ -326,6 +378,67 @@ def fake_tracking(self): yield i + 1 +# TODO: REQUIRES TO RUN pip install src/co-tracker +def cotrack_online( + w, + video, + keypoints, + device: str = "cpu", +) -> np.ndarray: + w.log("COTRACKING") + w.log(video.shape) + w.log(keypoints.shape) + def _process_step(window_frames, is_first_step, queries): + video_chunk = ( + torch.tensor(np.stack(window_frames[-model.step * 2:]), device=device) + .float() + .permute(0, 3, 1, 2)[None] + ) # (1, T, 3, H, W) + return model(video_chunk, is_first_step=is_first_step, queries=queries[None]) + + # model = CoTrackerOnlinePredictor( + # checkpoint=Path( + # "/home/lucas/Projects/deeplabcut-tracking/models/cotracker2.pth" + # ) + # ) + n_frames = len(video) + n_animals, n_keypoints = keypoints.shape[:2] + + model = torch.hub.load("facebookresearch/co-tracker", "cotracker2_online") + model = model.to(device) + video = torch.from_numpy(video).permute(0, 3, 1, 2).unsqueeze(0).float() + window_frames = [] + + queries = np.zeros((n_animals * n_keypoints, 3)) + queries[:, 1:] = keypoints.reshape((-1, 2)) + queries = torch.from_numpy(queries).to(device).float() + + # Iterating over video frames, processing one window at a time: + is_first_step = True + i = 0 + for i, frame in enumerate(video[0]): + frame = frame.permute(1, 2, 0) + if i % model.step == 0 and i != 0: + pred_tracks, pred_visibility = _process_step( + window_frames, + is_first_step, + queries=queries, + ) + is_first_step = False + window_frames.append(frame) + w.log("DONE WITH FRAME") + + # Processing final frames in case video length is not a multiple of model.step + # TODO: Use visibility + pred_tracks, pred_visibility = _process_step( + window_frames[-(i % model.step) - model.step - 1:], + is_first_step, + queries=queries, + ) + tracks = pred_tracks.squeeze().cpu().numpy() + return tracks.reshape((n_frames, n_animals, n_keypoints, 2)) + + def track_mock( video: np.ndarray, keypoints: np.ndarray, From 925e8c8e1e8d11bc2c948ee0e7abddcf5234872a Mon Sep 17 00:00:00 2001 From: Niels Poulsen Date: Fri, 26 Apr 2024 19:47:30 +0200 Subject: [PATCH 13/34] closer --- src/napari_deeplabcut/_tracking_worker.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/napari_deeplabcut/_tracking_worker.py b/src/napari_deeplabcut/_tracking_worker.py index 1a9a456..d0d702f 100644 --- a/src/napari_deeplabcut/_tracking_worker.py +++ b/src/napari_deeplabcut/_tracking_worker.py @@ -338,9 +338,11 @@ def run_tracking( ): """Run the tracking.""" self.log("Started tracking") - self.log(self._video) - self.log(self._keypoints) - tracks = cotrack_online(self, self._video, self._keypoints) + with open("log.txt", "w") as f: + f.write(f"{self._video.shape}") + f.write(f"{self._keypoints.shape}") + + tracks = cotrack_online(self, np.array(self._video), np.array(self._keypoints)) self.log("Finished tracking") track_path = Path(self._root) / "TrackedData.h5" self.save_tracking_data(track_path, tracks, "CoTracker") @@ -388,6 +390,15 @@ def cotrack_online( w.log("COTRACKING") w.log(video.shape) w.log(keypoints.shape) + k = keypoints[keypoints[:, 0] == 0][:, 1:] + with open("log_cotrack.txt", "w") as f: + f.write(f"video={video.shape}\n") + f.write(f"keypoints={keypoints.shape}\n") + f.write(f"{keypoints}\n") + f.write(f"k={k.shape}\n") + f.write(f"{k}\n") + keypoints = k.reshape((2, 4, 2)) + def _process_step(window_frames, is_first_step, queries): video_chunk = ( torch.tensor(np.stack(window_frames[-model.step * 2:]), device=device) From 2d653f7806ea7459ccdcce657acef73213fc8258 Mon Sep 17 00:00:00 2001 From: Niels Poulsen Date: Fri, 26 Apr 2024 19:49:58 +0200 Subject: [PATCH 14/34] oops --- .gitmodules | 3 +++ 1 file changed, 3 insertions(+) create mode 100644 .gitmodules diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000..fbd3a51 --- /dev/null +++ b/.gitmodules @@ -0,0 +1,3 @@ +[submodule "src/co-tracker"] + path = src/co-tracker + url = https://github.com/facebookresearch/co-tracker From f903b2eb330e90cef501c92757c2e35678267317 Mon Sep 17 00:00:00 2001 From: Niels Poulsen Date: Fri, 26 Apr 2024 20:14:12 +0200 Subject: [PATCH 15/34] it works :) --- src/napari_deeplabcut/_tracking_worker.py | 45 +++++++++++++++++++---- 1 file changed, 38 insertions(+), 7 deletions(-) diff --git a/src/napari_deeplabcut/_tracking_worker.py b/src/napari_deeplabcut/_tracking_worker.py index d0d702f..393b71f 100644 --- a/src/napari_deeplabcut/_tracking_worker.py +++ b/src/napari_deeplabcut/_tracking_worker.py @@ -212,7 +212,7 @@ def _display_results(self, results): # path_test = "C:/Users/Cyril/Desktop/Code/DeepLabCut/examples/openfield-Pranav-2018-10-30/labeled-data/m4s1/CollectedData_Pranav.h5" from napari_deeplabcut._reader import read_hdf - path_test = results + path_test = str(results) keypoint_data, metadata, _ = read_hdf(path_test)[0] # hdf data contains : keypoint data, metadata, and "points" # we want to create a points layer from the keypoint data @@ -242,7 +242,7 @@ def _on_yield(self, results): self._display_results(results) ############################ self.log.print_and_log(f"Yielded {results}") - self._update_progress_bar(results, 10) + # self._update_progress_bar(results, 10) ############################ def _on_start(self): @@ -343,6 +343,8 @@ def run_tracking( f.write(f"{self._keypoints.shape}") tracks = cotrack_online(self, np.array(self._video), np.array(self._keypoints)) + with open("log_finished_tracking.txt", "w") as f: + f.write(f"Done! {tracks.shape}") self.log("Finished tracking") track_path = Path(self._root) / "TrackedData.h5" self.save_tracking_data(track_path, tracks, "CoTracker") @@ -353,9 +355,11 @@ def save_tracking_data(self, path: Path, tracks: np.ndarray, scorer: str) -> Non levels = ["scorer", "individuals", "bodyparts", "coords"] kpt_entries = ["x", "y"] columns = [] - for i in self._individuals: - for b in self._bodyparts: - columns += [(scorer, i, b, entry) for entry in kpt_entries] + # for i in self._individuals: + # for b in self._bodyparts: + # columns += [(scorer, i, b, entry) for entry in kpt_entries] + for i, b in zip(self._individuals[:8], self._bodyparts[:8]): + columns += [(scorer, i, b, entry) for entry in kpt_entries] index = [] for img_path in self._image_paths: @@ -366,6 +370,13 @@ def save_tracking_data(self, path: Path, tracks: np.ndarray, scorer: str) -> Non else: raise ValueError(f"Incorrect image path format: {img_path}") + with open("log_df.txt", "w") as f: + f.write(f"{tracks.reshape((len(tracks), -1)).shape}\n") + f.write(f"{len(index)}\n") + f.write(f"{len(columns)}\n") + f.write(f"{self._individuals}\n") + f.write(f"{self._bodyparts}\n") + dataframe = pd.DataFrame( data=tracks.reshape((len(tracks), -1)), index=pd.MultiIndex.from_tuples(index), @@ -398,8 +409,18 @@ def cotrack_online( f.write(f"k={k.shape}\n") f.write(f"{k}\n") keypoints = k.reshape((2, 4, 2)) + k = np.zeros(keypoints.shape) + k[..., 0] = keypoints[..., 1] + k[..., 1] = keypoints[..., 0] + keypoints = k def _process_step(window_frames, is_first_step, queries): + with open("log_window_frames.txt", "w") as f: + f.write(f"{len(window_frames)}\n") + f.write(f"{model.step}\n") + f.write(f"{-model.step * 2}\n") + f.write(f"is_first_step={is_first_step}\n") + video_chunk = ( torch.tensor(np.stack(window_frames[-model.step * 2:]), device=device) .float() @@ -436,8 +457,8 @@ def _process_step(window_frames, is_first_step, queries): queries=queries, ) is_first_step = False - window_frames.append(frame) - w.log("DONE WITH FRAME") + window_frames.append(frame) + w.log("DONE WITH FRAME") # Processing final frames in case video length is not a multiple of model.step # TODO: Use visibility @@ -446,7 +467,17 @@ def _process_step(window_frames, is_first_step, queries): is_first_step, queries=queries, ) + + with open("log_pred_tracks.txt", "w") as f: + f.write(f"{len(pred_tracks)}\n") + f.write(f"{pred_tracks.shape}\n") + tracks = pred_tracks.squeeze().cpu().numpy() + with open("log_pred_tracks_2.txt", "w") as f: + f.write(f"{len(tracks)}\n") + f.write(f"{tracks.shape}\n") + f.write(f"{(n_frames, n_animals, n_keypoints, 2)}\n") + return tracks.reshape((n_frames, n_animals, n_keypoints, 2)) From ec0333ea434c3181ead7d15a977076d48fdbfd64 Mon Sep 17 00:00:00 2001 From: Niels Poulsen Date: Sat, 27 Apr 2024 12:10:25 +0200 Subject: [PATCH 16/34] :) --- src/napari_deeplabcut/_tracking_worker.py | 45 +++++++++++++---------- 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/src/napari_deeplabcut/_tracking_worker.py b/src/napari_deeplabcut/_tracking_worker.py index 393b71f..819c047 100644 --- a/src/napari_deeplabcut/_tracking_worker.py +++ b/src/napari_deeplabcut/_tracking_worker.py @@ -181,17 +181,20 @@ def _start(self): def _setup_worker(self): metadata = self.keypoint_layer_dropdown.layer().metadata - properties = self.keypoint_layer_dropdown.layer().properties keypoint_cord = self.keypoint_layer_dropdown.layer_data() frames = self.video_layer_dropdown.layer_data() + header = metadata["header"] + bodyparts = header.bodyparts + individuals_ids = header.individuals + self._worker = TrackingWorker( # metadata["metadata"]["root"], # metadata["metadata"]["images"], metadata["root"], metadata["paths"], - properties["label"], - properties["id"], + bodyparts, + individuals_ids, frames, keypoint_cord, ) @@ -331,18 +334,20 @@ def warn(self, msg): """Log a warning.""" self.warn_signal.emit(msg) - def run_tracking( - self, - # video: np.ndarray, - # keypoints: np.ndarray, - ): + def run_tracking(self): """Run the tracking.""" self.log("Started tracking") with open("log.txt", "w") as f: f.write(f"{self._video.shape}") f.write(f"{self._keypoints.shape}") - tracks = cotrack_online(self, np.array(self._video), np.array(self._keypoints)) + tracks = cotrack_online( + self.log, + np.array(self._video), + np.array(self._keypoints), + len(self._individuals), + len(self._bodyparts), + ) with open("log_finished_tracking.txt", "w") as f: f.write(f"Done! {tracks.shape}") self.log("Finished tracking") @@ -355,11 +360,11 @@ def save_tracking_data(self, path: Path, tracks: np.ndarray, scorer: str) -> Non levels = ["scorer", "individuals", "bodyparts", "coords"] kpt_entries = ["x", "y"] columns = [] - # for i in self._individuals: - # for b in self._bodyparts: - # columns += [(scorer, i, b, entry) for entry in kpt_entries] - for i, b in zip(self._individuals[:8], self._bodyparts[:8]): - columns += [(scorer, i, b, entry) for entry in kpt_entries] + for i in self._individuals: + for b in self._bodyparts: + columns += [(scorer, i, b, entry) for entry in kpt_entries] + # for i, b in zip(self._individuals[:8], self._bodyparts[:8]): + # columns += [(scorer, i, b, entry) for entry in kpt_entries] index = [] for img_path in self._image_paths: @@ -393,14 +398,14 @@ def fake_tracking(self): # TODO: REQUIRES TO RUN pip install src/co-tracker def cotrack_online( - w, + log, video, keypoints, + n_animals, + n_bodyparts, device: str = "cpu", ) -> np.ndarray: - w.log("COTRACKING") - w.log(video.shape) - w.log(keypoints.shape) + log("COTRACKING") k = keypoints[keypoints[:, 0] == 0][:, 1:] with open("log_cotrack.txt", "w") as f: f.write(f"video={video.shape}\n") @@ -408,7 +413,7 @@ def cotrack_online( f.write(f"{keypoints}\n") f.write(f"k={k.shape}\n") f.write(f"{k}\n") - keypoints = k.reshape((2, 4, 2)) + keypoints = k.reshape((n_animals, n_bodyparts, 2)) k = np.zeros(keypoints.shape) k[..., 0] = keypoints[..., 1] k[..., 1] = keypoints[..., 0] @@ -458,7 +463,7 @@ def _process_step(window_frames, is_first_step, queries): ) is_first_step = False window_frames.append(frame) - w.log("DONE WITH FRAME") + log(f"Finished batch {i}") # Processing final frames in case video length is not a multiple of model.step # TODO: Use visibility From 350bebedb0ebca9efad06c5be4c67993da54045b Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 27 Apr 2024 12:11:42 +0200 Subject: [PATCH 17/34] Update .gitignore --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index 73d56d3..cb794c2 100644 --- a/.gitignore +++ b/.gitignore @@ -82,3 +82,5 @@ venv/ # written by setuptools_scm **/_version.py + +*.txt \ No newline at end of file From 623662ae5a0c3229e60f9d8a2a350a9608dda584 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 27 Apr 2024 12:34:56 +0200 Subject: [PATCH 18/34] Redo config setup --- src/napari_deeplabcut/_tracking_worker.py | 91 ++++++++++++++++------- src/napari_deeplabcut/napari.yaml | 5 -- 2 files changed, 66 insertions(+), 30 deletions(-) diff --git a/src/napari_deeplabcut/_tracking_worker.py b/src/napari_deeplabcut/_tracking_worker.py index 819c047..9cb45d4 100644 --- a/src/napari_deeplabcut/_tracking_worker.py +++ b/src/napari_deeplabcut/_tracking_worker.py @@ -4,7 +4,7 @@ # - Prepare I/O with the actual tracking backend from functools import partial from pathlib import Path - +from dataclasses import dataclass import napari import numpy as np import pandas as pd @@ -29,7 +29,34 @@ add_widgets, get_time, ) -from napari.layers.utils.layer_utils import _features_to_properties + +@dataclass +class TrackingConfig: + ### Data ### + video: np.ndarray + keypoints: np.ndarray + # result_layer: napari.layers.Points + ### Metadata ### + root: str # path to the video + paths: list # list of paths to the video frames + bodyparts: list # list of bodyparts + individuals_ids: list # list of individuals + ### Config from data ### + n_frames: int + n_animals: int + n_keypoints: int + ### User config ### + retrack_frame: int = None + method : str = "CoTracker" # change when adding PIPS++ + device: str = "cpu" + +@dataclass +class TrackingResults: # Add anything relevant to be yielded by the worker here + """Used to update the results and progress bar. Is yielded by the worker.""" + result_keypoints: np.ndarray = None + layer_metadata: dict = None + hdf_path : str = None + pbar_update : tuple = None class TrackingModule(QWidget, metaclass=QWidgetSingleton): """Plugin for tracking.""" @@ -136,8 +163,8 @@ def _display_status_report(self): self.container_docked = True self.log.setVisible(True) - self.progress.setVisible(True) - self.progress.setValue(0) + # self.progress.setVisible(True) + # self.progress.setValue(0) def _update_progress_bar(self, current_frame, total_frame): """Update the progress bar.""" @@ -188,16 +215,30 @@ def _setup_worker(self): bodyparts = header.bodyparts individuals_ids = header.individuals - self._worker = TrackingWorker( - # metadata["metadata"]["root"], - # metadata["metadata"]["images"], - metadata["root"], - metadata["paths"], - bodyparts, - individuals_ids, - frames, - keypoint_cord, + self.worker_config = TrackingConfig( + video=frames, + keypoints=keypoint_cord, + root=metadata["root"], + paths=metadata["paths"], + bodyparts=bodyparts, + individuals_ids=individuals_ids, + n_frames=len(frames), + n_animals=len(individuals_ids), + n_keypoints=len(bodyparts), ) + self._worker = TrackingWorker(self.worker_config) + + + # self._worker = TrackingWorker( + # # metadata["metadata"]["root"], + # # metadata["metadata"]["images"], + # metadata["root"], + # metadata["paths"], + # bodyparts, + # individuals_ids, + # frames, + # keypoint_cord, + # ) self._worker.started.connect(self._on_start) @@ -304,24 +345,24 @@ def __init__(self, parent=None): class TrackingWorker(GeneratorWorker): """A custom worker to run tracking in.""" - def __init__(self, root, image_paths, bodyparts, individuals, video, keypoints): + def __init__(self, config: TrackingConfig): """Creates a TrackingWorker.""" - super().__init__(self.run_tracking) # TODO MUST BE CHANGED WHEN REAL TRACKING IS IMPLEMENTED - # super().__init__(self.fake_tracking) - self._root = root - self._image_paths = image_paths - self._bodyparts = bodyparts - self._individuals = individuals - self._video = video - self._keypoints = keypoints + super().__init__(self.run_tracking) + + self.config = config + + self._root = config.root + self._image_paths = config.paths + self._bodyparts = config.bodyparts + self._individuals = config.individuals_ids + self._video = config.video + self._keypoints = config.keypoints self._signals = LogSignal() self.log_signal = self._signals.log_signal self.log_w_replace_signal = self._signals.log_w_replace_signal self.warn_signal = self._signals.warn_signal self.error_signal = self._signals.error_signal - self.config = None # config # use if needed - def log(self, msg): """Log a message.""" self.log_signal.emit(msg) @@ -463,7 +504,7 @@ def _process_step(window_frames, is_first_step, queries): ) is_first_step = False window_frames.append(frame) - log(f"Finished batch {i}") + log(f"Finished frame {i}") # Processing final frames in case video length is not a multiple of model.step # TODO: Use visibility diff --git a/src/napari_deeplabcut/napari.yaml b/src/napari_deeplabcut/napari.yaml index 02f8907..f98c620 100644 --- a/src/napari_deeplabcut/napari.yaml +++ b/src/napari_deeplabcut/napari.yaml @@ -26,9 +26,6 @@ contributions: - id: napari-deeplabcut.make_keypoint_controls python_name: napari_deeplabcut._widgets:KeypointControls title: Make keypoint controls - - id: napari-deeplabcut.tracking_demo - python_name: napari_deeplabcut._tracking_worker:TrackingModule - title: Tracking demo readers: - command: napari-deeplabcut.get_hdf_reader accepts_directories: false @@ -55,5 +52,3 @@ contributions: widgets: - command: napari-deeplabcut.make_keypoint_controls display_name: Keypoint controls - - command: napari-deeplabcut.tracking_demo - display_name: Tracking demo From 171b4a6b66f4284a45dcfdbc3e82a444c5e7f1cc Mon Sep 17 00:00:00 2001 From: Niels Poulsen Date: Sat, 27 Apr 2024 12:13:46 +0200 Subject: [PATCH 19/34] minor log changes --- src/napari_deeplabcut/_tracking_worker.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/napari_deeplabcut/_tracking_worker.py b/src/napari_deeplabcut/_tracking_worker.py index 9cb45d4..64b8511 100644 --- a/src/napari_deeplabcut/_tracking_worker.py +++ b/src/napari_deeplabcut/_tracking_worker.py @@ -446,7 +446,7 @@ def cotrack_online( n_bodyparts, device: str = "cpu", ) -> np.ndarray: - log("COTRACKING") + log("Running CoTracker") k = keypoints[keypoints[:, 0] == 0][:, 1:] with open("log_cotrack.txt", "w") as f: f.write(f"video={video.shape}\n") @@ -454,6 +454,7 @@ def cotrack_online( f.write(f"{keypoints}\n") f.write(f"k={k.shape}\n") f.write(f"{k}\n") + keypoints = k.reshape((n_animals, n_bodyparts, 2)) k = np.zeros(keypoints.shape) k[..., 0] = keypoints[..., 1] @@ -474,11 +475,6 @@ def _process_step(window_frames, is_first_step, queries): ) # (1, T, 3, H, W) return model(video_chunk, is_first_step=is_first_step, queries=queries[None]) - # model = CoTrackerOnlinePredictor( - # checkpoint=Path( - # "/home/lucas/Projects/deeplabcut-tracking/models/cotracker2.pth" - # ) - # ) n_frames = len(video) n_animals, n_keypoints = keypoints.shape[:2] From c6d3de647a7e11c36e876626062ebaa64d6b37e7 Mon Sep 17 00:00:00 2001 From: Niels Poulsen Date: Sat, 27 Apr 2024 12:57:41 +0200 Subject: [PATCH 20/34] cleaned tracking, added init_frame option --- src/napari_deeplabcut/_tracking_worker.py | 26 ++++++++++------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/src/napari_deeplabcut/_tracking_worker.py b/src/napari_deeplabcut/_tracking_worker.py index 64b8511..0b1cc31 100644 --- a/src/napari_deeplabcut/_tracking_worker.py +++ b/src/napari_deeplabcut/_tracking_worker.py @@ -382,12 +382,19 @@ def run_tracking(self): f.write(f"{self._video.shape}") f.write(f"{self._keypoints.shape}") + init_frame = 0 + + video_frames = np.array(self._video) + video_frames = video_frames[init_frame:] + + keypoints = np.array(self._keypoints) + keypoints = keypoints[keypoints[:, 0] == init_frame][:, [2, 1]] + keypoints = keypoints.reshape((len(self._individuals), len(self._bodyparts), 2)) + tracks = cotrack_online( self.log, - np.array(self._video), - np.array(self._keypoints), - len(self._individuals), - len(self._bodyparts), + video_frames, + keypoints, ) with open("log_finished_tracking.txt", "w") as f: f.write(f"Done! {tracks.shape}") @@ -442,24 +449,13 @@ def cotrack_online( log, video, keypoints, - n_animals, - n_bodyparts, device: str = "cpu", ) -> np.ndarray: log("Running CoTracker") - k = keypoints[keypoints[:, 0] == 0][:, 1:] with open("log_cotrack.txt", "w") as f: f.write(f"video={video.shape}\n") f.write(f"keypoints={keypoints.shape}\n") f.write(f"{keypoints}\n") - f.write(f"k={k.shape}\n") - f.write(f"{k}\n") - - keypoints = k.reshape((n_animals, n_bodyparts, 2)) - k = np.zeros(keypoints.shape) - k[..., 0] = keypoints[..., 1] - k[..., 1] = keypoints[..., 0] - keypoints = k def _process_step(window_frames, is_first_step, queries): with open("log_window_frames.txt", "w") as f: From cc87fd2e19ab8c485bd88b1d8ef89db09641a320 Mon Sep 17 00:00:00 2001 From: Niels Poulsen Date: Sat, 27 Apr 2024 13:04:02 +0200 Subject: [PATCH 21/34] added init frame parameter --- src/napari_deeplabcut/_tracking_worker.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/napari_deeplabcut/_tracking_worker.py b/src/napari_deeplabcut/_tracking_worker.py index 0b1cc31..5b8ae1a 100644 --- a/src/napari_deeplabcut/_tracking_worker.py +++ b/src/napari_deeplabcut/_tracking_worker.py @@ -399,7 +399,7 @@ def run_tracking(self): with open("log_finished_tracking.txt", "w") as f: f.write(f"Done! {tracks.shape}") self.log("Finished tracking") - track_path = Path(self._root) / "TrackedData.h5" + track_path = Path(self._root) / f"TrackedData_start{init_frame}.h5" self.save_tracking_data(track_path, tracks, "CoTracker") self.log("Finished saving") yield track_path @@ -414,8 +414,9 @@ def save_tracking_data(self, path: Path, tracks: np.ndarray, scorer: str) -> Non # for i, b in zip(self._individuals[:8], self._bodyparts[:8]): # columns += [(scorer, i, b, entry) for entry in kpt_entries] + init_frame = 0 index = [] - for img_path in self._image_paths: + for img_path in self._image_paths[init_frame:]: if isinstance(img_path, str): index.append(tuple(Path(img_path).parts)) elif isinstance(img_path, tuple): From 2231eeb60fc3da77a90953b82338a59168acc6ca Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 27 Apr 2024 13:22:51 +0200 Subject: [PATCH 22/34] First functional prototype of retracking --- src/napari_deeplabcut/_tracking_worker.py | 36 ++++++++++++++++++----- 1 file changed, 29 insertions(+), 7 deletions(-) diff --git a/src/napari_deeplabcut/_tracking_worker.py b/src/napari_deeplabcut/_tracking_worker.py index 5b8ae1a..a0cda91 100644 --- a/src/napari_deeplabcut/_tracking_worker.py +++ b/src/napari_deeplabcut/_tracking_worker.py @@ -46,7 +46,7 @@ class TrackingConfig: n_animals: int n_keypoints: int ### User config ### - retrack_frame: int = None + retrack_frame_id: int = None method : str = "CoTracker" # change when adding PIPS++ device: str = "cpu" @@ -66,7 +66,9 @@ def __init__(self, napari_viewer: "napari.viewer.Viewer", parent=None): super().__init__(parent=parent) self._viewer = napari_viewer self._worker = None - self._keypoint_layer = None + self._video_layer : napari.layers.Image = None + self._keypoint_layer : napari.layers.Points = None + self.result_layer : napari.layers.Points = None ### Widgets ### self.video_layer_dropdown = LayerSelecter( self._viewer, @@ -99,6 +101,7 @@ def __init__(self, napari_viewer: "napari.viewer.Viewer", parent=None): self.log.setVisible(False) """Read-only display for process-related info. Use only for info destined to user.""" self._build() + self._viewer.dims.events.current_step.connect(self._update_start_button_display) # Use @property to get/set the keypoint layer @property @@ -165,6 +168,21 @@ def _display_status_report(self): self.log.setVisible(True) # self.progress.setVisible(True) # self.progress.setValue(0) + + def _update_start_button_display(self): + """Update the start button display.""" + if self._worker is None: + return + if self._worker.is_running: + return + if not self._worker.is_running and self.result_layer is not None: + current_frame = self._viewer.dims.current_step[0] + if current_frame == 0: + self.start_button.setText("Start") + return + self.start_button.setText(f"Retrack from frame {current_frame}") + else: + self.start_button.setText("Start") def _update_progress_bar(self, current_frame, total_frame): """Update the progress bar.""" @@ -214,6 +232,8 @@ def _setup_worker(self): header = metadata["header"] bodyparts = header.bodyparts individuals_ids = header.individuals + + current_frame = self._viewer.dims.current_step[0] self.worker_config = TrackingConfig( video=frames, @@ -225,6 +245,7 @@ def _setup_worker(self): n_frames=len(frames), n_animals=len(individuals_ids), n_keypoints=len(bodyparts), + retrack_frame_id=current_frame if self.result_layer is not None else None, ) self._worker = TrackingWorker(self.worker_config) @@ -262,7 +283,7 @@ def _display_results(self, results): # we want to create a points layer from the keypoint data # layer properties (dict) should be populated with metadata print(metadata) - layer = self._viewer.add_points( + return self._viewer.add_points( ### data ### keypoint_data, name="keypoints_hdf_test", @@ -279,11 +300,12 @@ def _display_results(self, results): edge_width_is_relative=metadata["edge_width_is_relative"], size=metadata["size"], ) + def _on_yield(self, results): # TODO : display the results in the viewer # Testing version where an int i is yielded - self._display_results(results) + self.result_layer = self._display_results(results) ############################ self.log.print_and_log(f"Yielded {results}") # self._update_progress_bar(results, 10) @@ -377,12 +399,12 @@ def warn(self, msg): def run_tracking(self): """Run the tracking.""" - self.log("Started tracking") with open("log.txt", "w") as f: f.write(f"{self._video.shape}") f.write(f"{self._keypoints.shape}") - init_frame = 0 + init_frame = self.config.retrack_frame_id if self.config.retrack_frame_id is not None else 0 + self.log(f"Started tracking from frame {init_frame}") video_frames = np.array(self._video) video_frames = video_frames[init_frame:] @@ -399,7 +421,7 @@ def run_tracking(self): with open("log_finished_tracking.txt", "w") as f: f.write(f"Done! {tracks.shape}") self.log("Finished tracking") - track_path = Path(self._root) / f"TrackedData_start{init_frame}.h5" + track_path = Path(self._root) / f"TrackedData_frame_{init_frame}.h5" self.save_tracking_data(track_path, tracks, "CoTracker") self.log("Finished saving") yield track_path From 104c6abaddf7711118c9b6773f907e72b7cb3032 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 27 Apr 2024 13:24:56 +0200 Subject: [PATCH 23/34] Fix saving with frame id --- src/napari_deeplabcut/_tracking_worker.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/napari_deeplabcut/_tracking_worker.py b/src/napari_deeplabcut/_tracking_worker.py index a0cda91..87cc1ca 100644 --- a/src/napari_deeplabcut/_tracking_worker.py +++ b/src/napari_deeplabcut/_tracking_worker.py @@ -426,7 +426,7 @@ def run_tracking(self): self.log("Finished saving") yield track_path - def save_tracking_data(self, path: Path, tracks: np.ndarray, scorer: str) -> None: + def save_tracking_data(self, path: Path, tracks: np.ndarray, scorer: str, frame: int = 0) -> None: levels = ["scorer", "individuals", "bodyparts", "coords"] kpt_entries = ["x", "y"] columns = [] @@ -436,9 +436,8 @@ def save_tracking_data(self, path: Path, tracks: np.ndarray, scorer: str) -> Non # for i, b in zip(self._individuals[:8], self._bodyparts[:8]): # columns += [(scorer, i, b, entry) for entry in kpt_entries] - init_frame = 0 index = [] - for img_path in self._image_paths[init_frame:]: + for img_path in self._image_paths[frame:]: if isinstance(img_path, str): index.append(tuple(Path(img_path).parts)) elif isinstance(img_path, tuple): From c69b397b6bfd43a6f180f6a91077b8e1c171717f Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 27 Apr 2024 13:37:58 +0200 Subject: [PATCH 24/34] Fix error in saving retracking --- src/napari_deeplabcut/_tracking_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/napari_deeplabcut/_tracking_worker.py b/src/napari_deeplabcut/_tracking_worker.py index 87cc1ca..2c9409f 100644 --- a/src/napari_deeplabcut/_tracking_worker.py +++ b/src/napari_deeplabcut/_tracking_worker.py @@ -422,7 +422,7 @@ def run_tracking(self): f.write(f"Done! {tracks.shape}") self.log("Finished tracking") track_path = Path(self._root) / f"TrackedData_frame_{init_frame}.h5" - self.save_tracking_data(track_path, tracks, "CoTracker") + self.save_tracking_data(track_path, tracks, "CoTracker", frame=init_frame) self.log("Finished saving") yield track_path From 2fd3f402f97b947544c82e506a82ebe58996b4e7 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 27 Apr 2024 13:56:29 +0200 Subject: [PATCH 25/34] Update _tracking_worker.py --- src/napari_deeplabcut/_tracking_worker.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/napari_deeplabcut/_tracking_worker.py b/src/napari_deeplabcut/_tracking_worker.py index 2c9409f..498a65d 100644 --- a/src/napari_deeplabcut/_tracking_worker.py +++ b/src/napari_deeplabcut/_tracking_worker.py @@ -48,7 +48,7 @@ class TrackingConfig: ### User config ### retrack_frame_id: int = None method : str = "CoTracker" # change when adding PIPS++ - device: str = "cpu" + device: str = "cpu" if not torch.cuda.is_available() else "cuda" @dataclass class TrackingResults: # Add anything relevant to be yielded by the worker here @@ -84,6 +84,8 @@ def __init__(self, napari_viewer: "napari.viewer.Viewer", parent=None): ) self.start_button = QPushButton("Start tracking") self.start_button.clicked.connect(self._start) + self.enable_retracking = False + ############################# # status report docked widget self.container_docked = False # check if already docked @@ -171,10 +173,9 @@ def _display_status_report(self): def _update_start_button_display(self): """Update the start button display.""" - if self._worker is None: - return if self._worker.is_running: return + if not self._worker.is_running and self.result_layer is not None: current_frame = self._viewer.dims.current_step[0] if current_frame == 0: @@ -234,6 +235,9 @@ def _setup_worker(self): individuals_ids = header.individuals current_frame = self._viewer.dims.current_step[0] + retrack_frame_id = None + if current_frame != 0 and self.enable_retracking: + retrack_frame_id = current_frame self.worker_config = TrackingConfig( video=frames, @@ -245,7 +249,7 @@ def _setup_worker(self): n_frames=len(frames), n_animals=len(individuals_ids), n_keypoints=len(bodyparts), - retrack_frame_id=current_frame if self.result_layer is not None else None, + retrack_frame_id=retrack_frame_id, ) self._worker = TrackingWorker(self.worker_config) @@ -286,7 +290,7 @@ def _display_results(self, results): return self._viewer.add_points( ### data ### keypoint_data, - name="keypoints_hdf_test", + name=f"Tracked keypoints - frame {self._worker.config.retrack_frame_id}", metadata=metadata["metadata"], # features=metadata["properties"], properties=metadata["properties"], @@ -332,6 +336,7 @@ def _on_finish(self): self.start_button.setText("Start") self._worker = None + self.enable_retracking = True return True # signal clean exit @@ -417,6 +422,7 @@ def run_tracking(self): self.log, video_frames, keypoints, + device=self.config.device, ) with open("log_finished_tracking.txt", "w") as f: f.write(f"Done! {tracks.shape}") From b4dcb6b02ad983206591bd57d101c870a95ef1a5 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 27 Apr 2024 13:58:23 +0200 Subject: [PATCH 26/34] Remove unnecessary conversion --- src/napari_deeplabcut/_tracking_worker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/napari_deeplabcut/_tracking_worker.py b/src/napari_deeplabcut/_tracking_worker.py index 498a65d..5747912 100644 --- a/src/napari_deeplabcut/_tracking_worker.py +++ b/src/napari_deeplabcut/_tracking_worker.py @@ -411,8 +411,8 @@ def run_tracking(self): init_frame = self.config.retrack_frame_id if self.config.retrack_frame_id is not None else 0 self.log(f"Started tracking from frame {init_frame}") - video_frames = np.array(self._video) - video_frames = video_frames[init_frame:] + + video_frames = self._video[init_frame:] keypoints = np.array(self._keypoints) keypoints = keypoints[keypoints[:, 0] == init_frame][:, [2, 1]] From d8f94be44c690ead4c206f2ee4ae93e6eeae5e8e Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 27 Apr 2024 14:00:56 +0200 Subject: [PATCH 27/34] Update _tracking_worker.py --- src/napari_deeplabcut/_tracking_worker.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/src/napari_deeplabcut/_tracking_worker.py b/src/napari_deeplabcut/_tracking_worker.py index 5747912..6dd55ee 100644 --- a/src/napari_deeplabcut/_tracking_worker.py +++ b/src/napari_deeplabcut/_tracking_worker.py @@ -173,10 +173,14 @@ def _display_status_report(self): def _update_start_button_display(self): """Update the start button display.""" - if self._worker.is_running: - return + if self._worker is not None: + if self._worker.is_running: + return - if not self._worker.is_running and self.result_layer is not None: + if self.result_layer is not None: + if self.worker is not None: + if self.worker.is_running: + return current_frame = self._viewer.dims.current_step[0] if current_frame == 0: self.start_button.setText("Start") @@ -408,14 +412,14 @@ def run_tracking(self): f.write(f"{self._video.shape}") f.write(f"{self._keypoints.shape}") - init_frame = self.config.retrack_frame_id if self.config.retrack_frame_id is not None else 0 - self.log(f"Started tracking from frame {init_frame}") + retrack_frame = self.config.retrack_frame_id if self.config.retrack_frame_id is not None else 0 + self.log(f"Started tracking from frame {retrack_frame}") - video_frames = self._video[init_frame:] + video_frames = self._video[retrack_frame:] keypoints = np.array(self._keypoints) - keypoints = keypoints[keypoints[:, 0] == init_frame][:, [2, 1]] + keypoints = keypoints[keypoints[:, 0] == retrack_frame][:, [2, 1]] keypoints = keypoints.reshape((len(self._individuals), len(self._bodyparts), 2)) tracks = cotrack_online( @@ -427,8 +431,8 @@ def run_tracking(self): with open("log_finished_tracking.txt", "w") as f: f.write(f"Done! {tracks.shape}") self.log("Finished tracking") - track_path = Path(self._root) / f"TrackedData_frame_{init_frame}.h5" - self.save_tracking_data(track_path, tracks, "CoTracker", frame=init_frame) + track_path = Path(self._root) / f"TrackedData_frame_{retrack_frame}.h5" + self.save_tracking_data(track_path, tracks, "CoTracker", frame=retrack_frame) self.log("Finished saving") yield track_path From ff1965d5e5f0d6df56ce73d13750eb7b436fc737 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 27 Apr 2024 14:05:38 +0200 Subject: [PATCH 28/34] Fix worker --- src/napari_deeplabcut/_tracking_worker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/napari_deeplabcut/_tracking_worker.py b/src/napari_deeplabcut/_tracking_worker.py index 6dd55ee..ac1557d 100644 --- a/src/napari_deeplabcut/_tracking_worker.py +++ b/src/napari_deeplabcut/_tracking_worker.py @@ -178,8 +178,8 @@ def _update_start_button_display(self): return if self.result_layer is not None: - if self.worker is not None: - if self.worker.is_running: + if self._worker is not None: + if self._worker.is_running: return current_frame = self._viewer.dims.current_step[0] if current_frame == 0: From 7858920b87ed09b7690dba3fab0f0a5dd17406be Mon Sep 17 00:00:00 2001 From: Arash Sal Moslehian Date: Sat, 27 Apr 2024 12:09:56 +0000 Subject: [PATCH 29/34] tracking --- src/napari_deeplabcut/_tracking_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/napari_deeplabcut/_tracking_worker.py b/src/napari_deeplabcut/_tracking_worker.py index ac1557d..4fd164b 100644 --- a/src/napari_deeplabcut/_tracking_worker.py +++ b/src/napari_deeplabcut/_tracking_worker.py @@ -416,7 +416,7 @@ def run_tracking(self): self.log(f"Started tracking from frame {retrack_frame}") - video_frames = self._video[retrack_frame:] + video_frames = np.array(self._video[retrack_frame:]) keypoints = np.array(self._keypoints) keypoints = keypoints[keypoints[:, 0] == retrack_frame][:, [2, 1]] From 3e448be7cd117c850347afb05e92a708d4f89cef Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 27 Apr 2024 14:18:56 +0200 Subject: [PATCH 30/34] Functional retracking --- src/napari_deeplabcut/_tracking_worker.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/napari_deeplabcut/_tracking_worker.py b/src/napari_deeplabcut/_tracking_worker.py index 4fd164b..cfa0574 100644 --- a/src/napari_deeplabcut/_tracking_worker.py +++ b/src/napari_deeplabcut/_tracking_worker.py @@ -291,10 +291,11 @@ def _display_results(self, results): # we want to create a points layer from the keypoint data # layer properties (dict) should be populated with metadata print(metadata) + frame_id = self._worker.config.retrack_frame_id if self._worker.config.retrack_frame_id is not None else 0 return self._viewer.add_points( ### data ### keypoint_data, - name=f"Tracked keypoints - frame {self._worker.config.retrack_frame_id}", + name=f"Tracked keypoints - frame {frame_id}", metadata=metadata["metadata"], # features=metadata["properties"], properties=metadata["properties"], @@ -388,6 +389,7 @@ def __init__(self, config: TrackingConfig): self._individuals = config.individuals_ids self._video = config.video self._keypoints = config.keypoints + self._signals = LogSignal() self.log_signal = self._signals.log_signal self.log_w_replace_signal = self._signals.log_w_replace_signal @@ -431,6 +433,7 @@ def run_tracking(self): with open("log_finished_tracking.txt", "w") as f: f.write(f"Done! {tracks.shape}") self.log("Finished tracking") + retrack_frame = 0 if retrack_frame is None else retrack_frame track_path = Path(self._root) / f"TrackedData_frame_{retrack_frame}.h5" self.save_tracking_data(track_path, tracks, "CoTracker", frame=retrack_frame) self.log("Finished saving") From 1d6476983719a9a74f94fe9341ad4b343c8df33d Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 27 Apr 2024 14:25:43 +0200 Subject: [PATCH 31/34] Fix start button after finish --- src/napari_deeplabcut/_tracking_worker.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/napari_deeplabcut/_tracking_worker.py b/src/napari_deeplabcut/_tracking_worker.py index cfa0574..cb2d7ba 100644 --- a/src/napari_deeplabcut/_tracking_worker.py +++ b/src/napari_deeplabcut/_tracking_worker.py @@ -339,6 +339,7 @@ def _on_finish(self): self.log.print_and_log(f"\nWorker finished at {get_time()}") self.log.print_and_log("*" * 20) self.start_button.setText("Start") + self._update_start_button_display() self._worker = None self.enable_retracking = True From 2cfa3413b16cf60a700a05d2c1bda1ecc3989736 Mon Sep 17 00:00:00 2001 From: C-Achard Date: Sat, 27 Apr 2024 14:35:02 +0200 Subject: [PATCH 32/34] Add retracking when already run --- src/napari_deeplabcut/_tracking_worker.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/napari_deeplabcut/_tracking_worker.py b/src/napari_deeplabcut/_tracking_worker.py index cb2d7ba..e9fee4c 100644 --- a/src/napari_deeplabcut/_tracking_worker.py +++ b/src/napari_deeplabcut/_tracking_worker.py @@ -85,6 +85,8 @@ def __init__(self, napari_viewer: "napari.viewer.Viewer", parent=None): self.start_button = QPushButton("Start tracking") self.start_button.clicked.connect(self._start) self.enable_retracking = False + self._check_for_retracking_availability() + self._viewer.layers.events.inserted.connect(self._check_for_retracking_availability) ############################# # status report docked widget @@ -177,7 +179,7 @@ def _update_start_button_display(self): if self._worker.is_running: return - if self.result_layer is not None: + if self.result_layer is not None and self.enable_retracking: if self._worker is not None: if self._worker.is_running: return @@ -189,6 +191,13 @@ def _update_start_button_display(self): else: self.start_button.setText("Start") + def _check_for_retracking_availability(self): + for layer in self._viewer.layers: + if "Tracked keypoints - frame 0" in layer.name: + self.enable_retracking = True + self.result_layer = layer + self._update_start_button_display() + def _update_progress_bar(self, current_frame, total_frame): """Update the progress bar.""" pbar_value = (current_frame / total_frame) * 100 @@ -435,7 +444,7 @@ def run_tracking(self): f.write(f"Done! {tracks.shape}") self.log("Finished tracking") retrack_frame = 0 if retrack_frame is None else retrack_frame - track_path = Path(self._root) / f"TrackedData_frame_{retrack_frame}.h5" + track_path = Path(self._root) / f"Tracked keypoints - frame {retrack_frame}.h5" self.save_tracking_data(track_path, tracks, "CoTracker", frame=retrack_frame) self.log("Finished saving") yield track_path From 2bacd136abb00c1687222d2d8262be32f68c0cb2 Mon Sep 17 00:00:00 2001 From: Niels Poulsen Date: Sat, 27 Apr 2024 15:00:39 +0200 Subject: [PATCH 33/34] bug fix --- src/napari_deeplabcut/_tracking_worker.py | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/src/napari_deeplabcut/_tracking_worker.py b/src/napari_deeplabcut/_tracking_worker.py index e9fee4c..2d90c4d 100644 --- a/src/napari_deeplabcut/_tracking_worker.py +++ b/src/napari_deeplabcut/_tracking_worker.py @@ -30,6 +30,11 @@ get_time, ) + +def get_track_filename(start_frame: int) -> str: + return f"TrackedKeypoints_start_{start_frame}.h5" + + @dataclass class TrackingConfig: ### Data ### @@ -192,8 +197,13 @@ def _update_start_button_display(self): self.start_button.setText("Start") def _check_for_retracking_availability(self): + base_track_filename = get_track_filename(start_frame=0) + print(base_track_filename) + base_track_stem = base_track_filename.split(".h5")[0] + print("stem", base_track_stem) for layer in self._viewer.layers: - if "Tracked keypoints - frame 0" in layer.name: + print(f"Layer: '{layer.name}'") + if base_track_stem in layer.name: self.enable_retracking = True self.result_layer = layer self._update_start_button_display() @@ -304,7 +314,7 @@ def _display_results(self, results): return self._viewer.add_points( ### data ### keypoint_data, - name=f"Tracked keypoints - frame {frame_id}", + name=results.stem, metadata=metadata["metadata"], # features=metadata["properties"], properties=metadata["properties"], @@ -431,6 +441,10 @@ def run_tracking(self): video_frames = np.array(self._video[retrack_frame:]) keypoints = np.array(self._keypoints) + with open("log_frame_indices.txt", "w") as f: + f.write(f"{keypoints[:, 0]}\n\n") + f.write(f"{keypoints}") + keypoints = keypoints[keypoints[:, 0] == retrack_frame][:, [2, 1]] keypoints = keypoints.reshape((len(self._individuals), len(self._bodyparts), 2)) @@ -444,7 +458,7 @@ def run_tracking(self): f.write(f"Done! {tracks.shape}") self.log("Finished tracking") retrack_frame = 0 if retrack_frame is None else retrack_frame - track_path = Path(self._root) / f"Tracked keypoints - frame {retrack_frame}.h5" + track_path = Path(self._root) / get_track_filename(start_frame=retrack_frame) self.save_tracking_data(track_path, tracks, "CoTracker", frame=retrack_frame) self.log("Finished saving") yield track_path From 2531501192c07819c58bc94bc0c5c21e45f94067 Mon Sep 17 00:00:00 2001 From: Niels Poulsen Date: Sat, 27 Apr 2024 15:16:28 +0200 Subject: [PATCH 34/34] fix --- src/napari_deeplabcut/_tracking_worker.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/src/napari_deeplabcut/_tracking_worker.py b/src/napari_deeplabcut/_tracking_worker.py index 2d90c4d..a1f844f 100644 --- a/src/napari_deeplabcut/_tracking_worker.py +++ b/src/napari_deeplabcut/_tracking_worker.py @@ -460,6 +460,7 @@ def run_tracking(self): retrack_frame = 0 if retrack_frame is None else retrack_frame track_path = Path(self._root) / get_track_filename(start_frame=retrack_frame) self.save_tracking_data(track_path, tracks, "CoTracker", frame=retrack_frame) + # self.create_combined_tracking_data() self.log("Finished saving") yield track_path @@ -496,6 +497,27 @@ def save_tracking_data(self, path: Path, tracks: np.ndarray, scorer: str, frame: ) dataframe.to_hdf(path, key="df_with_missing") + def create_combined_tracking_data(self): + track_files = [ + p for p in Path(self._root).iterdir() + if p.is_file() and p.stem.startswith("TrackedKeypoints") + ] + if len(track_files) <= 1: + return + + df_tracks = { + int(p.stem.split("_")[-1]): pd.read_hdf(p, key="df_with_missing") + for p in track_files + } + df = df_tracks[0] + indices = [i for i in sorted(df_tracks.keys()) if i > 0] + for idx in indices: + new_data = df_tracks[idx] + df.iloc[idx:] = new_data.iloc[idx:] + + path = Path(self._root) / "CombinedTracks.h5" + df.to_hdf(path, key="df_with_missing") + def fake_tracking(self): """Fake tracking for testing purposes.""" for i in range(1):