diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index e8b61591..4ab1cae2 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -1,10 +1,13 @@ -FROM python:3.9-slim@sha256:5f0192a4f58a6ce99f732fe05e3b3d00f12ae62e183886bca3ebe3d202686c7f +ARG PY_VER=3.11 +ARG DISTRO=bullseye +FROM mcr.microsoft.com/devcontainers/python:${PY_VER}-${DISTRO} -ENV PATH /usr/local/bin:$PATH -ENV PYTHON_VERSION 3.9.17 +# Avoid warnings by switching to noninteractive +ENV DEBIAN_FRONTEND=noninteractive + +USER root RUN \ - adduser --system --disabled-password --shell /bin/bash vscode && \ # install docker apt-get update && \ apt-get install ca-certificates curl gnupg lsb-release -y && \ @@ -31,27 +34,19 @@ COPY ./ /tmp/element-moseq/ RUN \ # pipeline dependencies - apt-get update && \ - apt-get install -y gcc ffmpeg graphviz && \ - pip install ipywidgets && \ - pip install --no-cache-dir -e /tmp/element-moseq[kpms,elements,tests] && \ + apt-get install gcc g++ ffmpeg libsm6 libxext6 -y && \ + pip install --no-cache-dir -e /tmp/element-moseq[elements,tests] && \ # clean up - rm -rf /tmp/element-moseq/ && \ + rm -rf /tmp/element-moseq && \ apt-get clean -# Install CPU version for KPMS -RUN pip install "jax[cpu]==0.3.22" -f https://storage.googleapis.com/jax-releases/jax_releases.html - ENV DJ_HOST fakeservices.datajoint.io ENV DJ_USER root ENV DJ_PASS simple -ENV DATA_MOUNTPOINT /workspaces/element-moseq/example_data -ENV KPMS_ROOT_DATA_DIR $DATA_MOUNTPOINT/inbox -ENV KPMS_PROCESSED_DATA_DIR $DATA_MOUNTPOINT/outbox +ENV KPMS_ROOT_DATA_DIR /workspaces/element-moseq/example_data ENV DATABASE_PREFIX neuro_ USER vscode -CMD bash -c "sudo rm /var/run/docker.pid; sudo dockerd" -ENV LD_LIBRARY_PATH="/lib:/opt/conda/lib" +CMD bash -c "sudo rm /var/run/docker.pid; sudo dockerd" diff --git a/CHANGELOG.md b/CHANGELOG.md index 29c076d8..9feb0a95 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,39 @@ Observes [Semantic 
Versioning](https://semver.org/spec/v2.0.0.html) standard and [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention. +## [1.0.0] - 2025-09-10 + +> **BREAKING CHANGES** - This version contains breaking changes due to keypoint-moseq upgrade and API refactoring. Please review the changes below and update your code accordingly. + +### Breaking Changes ++ **BREAKING**: Upgrade keypoint-moseq from pinned 0.4.8 version to the latest version from source with breaking changes adding new features that are not compatible with the previous kpms versions ++ **BREAKING**: Rename `kpms_reader` functions to generate, load, and update kpms config files ++ **BREAKING**: Rename `PCAPrep` to `PreProcessing` table and add new attributes ++ **BREAKING**: Add feature to remove outlier keypoints in `PreProcessing` table with new `outlier_scale_factor` attribute in `PCATask`, only available in latest version of kpms ++ **BREAKING**: Add `sigmasq_loc` feature in `PreFit` and `FullFit` to automatically estimate sigmasq_loc (prior controlling the centroid movement across frames), only available in latest version of kpms + +### New Features and Fixes ++ Feat - Add support to load from both DLC `config.yml` and `config.yaml` file extensions ++ Feat - Add new `moseq_report` schema with comprehensive reporting capabilities ++ Feat - Refactor `PreProcessing` table to use 3-part make function, add a new `Video` part table, and add new attributes `video_duration`, `frame_rate` and `average_frame_rate` to store these new computations ++ Feat - Move and refactor `viz_utils` into new `plotting` module ++ Feat - Update `model_name` varchar and folder naming ++ Feat - Migrate from `setup.py` to `pyproject.toml` and `conda_env.yml` for modern Python packaging standards ++ Fix - Update devcontainer to use Python 3.11 and upgrade dependencies ++ Fix - Remove JAX dependencies from `pyproject.toml` ++ Fix - Update pre-commit hooks for improved linting and consistency ++ Fix - Correct generation 
of `kpms_dj_config.yml` and refactor `moseq_train` and `moseq_infer` to use the renamed functions ++ Fix - Refactor `moseq_infer` and `moseq_train`, and implement a three-part make function in the most resource-intensive functions ++ Fix - Update folder naming logic to use string of combined primary attributes instead of datetime in `PreFit` and `FullFit` ++ Fix - Improve path and directory handling using `Path` objects and robust existence checks ++ Fix - Remove redundancy of variables in `PreProcessing` table ++ Fix - Update deprecated datetime usage ++ Fix - Fix filename generation in `Inference` table ++ Fix - Bugfix in `Model` imported foreign key ++ Fix - Update tutorial_pipeline ++ Add - Update docstrings across all modules ++ Add - Update pipeline images to reflect new architecture + ## [0.3.2] - 2025-08-25 + Feat - modernize packaging and environment management migrating from `setup.py` to `pyproject.toml`and `env.yml` + Fix - JAX compatibility issues diff --git a/env.yml b/conda_env.yml similarity index 100% rename from env.yml rename to conda_env.yml diff --git a/element_moseq/moseq_infer.py b/element_moseq/moseq_infer.py index 9e38515f..0919dc45 100644 --- a/element_moseq/moseq_infer.py +++ b/element_moseq/moseq_infer.py @@ -5,7 +5,6 @@ import importlib import inspect -import os from datetime import datetime, timezone from pathlib import Path @@ -14,7 +13,7 @@ from matplotlib import pyplot as plt from . import moseq_train -from .readers.kpms_reader import load_kpms_dj_config +from .readers import kpms_reader schema = dj.schema() _linking_module = None @@ -69,25 +68,23 @@ def activate( @schema class Model(dj.Manual): - """Register a model. + """Register a trained model. Attributes: model_id (int) : Unique ID for each model. model_name (varchar) : User-friendly model name. - model_dir (varchar) : Model directory relative to root data directory (e.g. `kpms_project_output_dir/2024_03_21-00_51_39`) - latent_dim (int) : Latent dimension of the model. 
- kappa (float) : Kappa value of the model. - model_desc (varchar) : Optional. User-defined description of the model + model_dir (varchar) : Model directory relative to root data directory. + model_desc (varchar) : Optional. User-defined description of the model. """ definition = """ - model_id : int # Unique ID for each model + model_id : int # Unique ID for each model --- - model_name : varchar(64) # User-friendly model name - model_dir : varchar(1000)# Model directory relative to root data directory - model_desc='' : varchar(1000)# Optional. User-defined description of the model - -> [nullable] moseq_train.SelectedFullFit + model_name : varchar(1000) # User-friendly model name + model_dir : varchar(1000) # Model directory relative to root data directory + model_desc='' : varchar(1000) # Optional. User-defined description of the model + -> [nullable] moseq_train.SelectedFullFit # Optional. FullFit key. """ @@ -136,6 +133,7 @@ class InferenceTask(dj.Manual): inference_output_dir (varchar) : Optional. Sub-directory where the results will be stored. inference_desc (varchar) : Optional. User-defined description of the inference task. num_iterations (int) : Optional. Number of iterations to use for the model inference. If null, the default number internally is 50. + task_mode (enum) : 'load': load computed analysis results, 'trigger': trigger computation """ definition = """ @@ -153,28 +151,32 @@ class InferenceTask(dj.Manual): @schema class Inference(dj.Computed): - """Infer the model from the checkpoint file and save the results as `results.h5` file. + """Infer the model from the checkpoint file and generate the results of segmenting continuous behavior into discrete syllables. Attributes: - InferenceTask (foreign_key) : `InferenceTask` key. - inference_duration (float) : Time duration (seconds) of the inference computation. + InferenceTask (foreign_key) : `InferenceTask` key. 
+ syllable_segmentation_file (attach) : File path of the syllable analysis results (HDF5 format) containing syllable labels, latent states, centroids, and headings. + inference_duration (float) : Time duration (seconds) of the inference computation. """ definition = """ - -> InferenceTask # `InferenceTask` key + -> InferenceTask # `InferenceTask` key --- - inference_duration=NULL : float # Time duration (seconds) of the inference computation + syllable_segmentation_file : attach # File path of the syllable analysis results (HDF5 format) containing syllable labels, latent states, centroids, and headings + inference_duration=NULL : float # Time duration (seconds) of the inference computation """ class MotionSequence(dj.Part): """Store the results of the model inference. Attributes: + InferenceTask (foreign key) : `InferenceTask` key. video_name (varchar) : Name of the video. syllable (longblob) : Syllable labels (z). The syllable label assigned to each frame (i.e. the state indexes assigned by the model). latent_state (longblob) : Inferred low-dim pose state (x). Low-dimensional representation of the animal's pose in each frame. These are similar to PCA scores, are modified to reflect the pose dynamics and noise estimates inferred by the model. centroid (longblob) : Inferred centroid (v). The centroid of the animal in each frame, as estimated by the model. heading (longblob) : Inferred heading (h). The heading of the animal in each frame, as estimated by the model. + motion_sequence_file (attach) : File path of the temporal sequence of motion data (CSV format). """ definition = """ @@ -185,6 +187,7 @@ class MotionSequence(dj.Part): latent_state : longblob # Inferred low-dim pose state (x). Low-dimensional representation of the animal's pose in each frame. These are similar to PCA scores, are modified to reflect the pose dynamics and noise estimates inferred by the model centroid : longblob # Inferred centroid (v). 
The centroid of the animal in each frame, as estimated by the model heading : longblob # Inferred heading (h). The heading of the animal in each frame, as estimated by the model + motion_sequence_file: attach # File path of the temporal sequence of motion data (CSV format) """ class GridMoviesSampledInstances(dj.Part): @@ -202,12 +205,56 @@ class GridMoviesSampledInstances(dj.Part): instances: longblob # List of instances shown in each in grid movie (in row-major order), where each instance is specified as a tuple with the video name, start frame and end frame """ - def make(self, key): + def make_fetch(self, key): """ - This function is used to infer the model results from the checkpoint file and store the results in `MotionSequence` and `GridMoviesSampledInstances` tables. + Fetch data required for model inference. + """ + ( + keypointset_dir, + inference_output_dir, + num_iterations, + model_id, + pose_estimation_method, + task_mode, + ) = (InferenceTask & key).fetch1( + "keypointset_dir", + "inference_output_dir", + "num_iterations", + "model_id", + "pose_estimation_method", + "task_mode", + ) + + return ( + keypointset_dir, + inference_output_dir, + num_iterations, + model_id, + pose_estimation_method, + task_mode, + ) + + def make_compute( + self, + key, + keypointset_dir, + inference_output_dir, + num_iterations, + model_id, + pose_estimation_method, + task_mode, + ): + """ + Compute model inference results. Args: key (dict): `InferenceTask` primary key. + keypointset_dir (str): Directory containing keypoint data. + inference_output_dir (str): Output directory for inference results. + num_iterations (int): Number of iterations for model fitting. + model_id (int): Model ID. + pose_estimation_method (str): Pose estimation method. + task_mode (str): Task mode ('trigger' or 'load'). Raises: FileNotFoundError: If no pca model (`pca.p`) found in the parent model directory. 
@@ -215,48 +262,32 @@ def make(self, key): NotImplementedError: If the format method is not `deeplabcut`. FileNotFoundError: If no valid `kpms_dj_config` found in the parent model directory. - High-level Logic: - 1. Fetch the `inference_output_dir` where the results will be stored, and if it does not exist, create it. - 2. Fetch the `model_name` and the `num_iterations` from the `InferenceTask` table. - 3. Load the most recent model checkpoint and the pca model from files in the `kpms_project_output_dir`. - 4. Load the keypoint data for inference as `filepath_patterns` and format it. - 5. Initialize and apply the model with the new keypoint data. - 6. If the `num_iterations` is set, fit the model with the new keypoint data for `num_iterations` iterations; otherwise, fit the model with the default number of iterations (50). - 7. Save the results as a CSV file and store the histogram showing the frequency of each syllable. - 8. Generate and save the plots showing the median trajectory of poses associated with each given syllable. - 9. Generate and save video clips showing examples of each syllable. - 10. Generate and save the dendrogram representing distances between each syllable's median trajectory. - 11. Insert the inference duration in the `Inference` table. - 12. Insert the results in the `MotionSequence` and `GridMoviesSampledInstances` tables. + Returns: + tuple: Inference results including duration, results data, and sampled instances. 
""" from keypoint_moseq import ( apply_model, + filter_centroids_headings, format_data, generate_grid_movies, generate_trajectory_plots, + get_syllable_instances, load_checkpoint, load_keypoints, load_pca, + load_results, plot_similarity_dendrogram, plot_syllable_frequencies, + sample_instances, save_results_as_csv, ) - ( - keypointset_dir, - inference_output_dir, - num_iterations, - model_id, - pose_estimation_method, - task_mode, - ) = (InferenceTask & key).fetch1( - "keypointset_dir", - "inference_output_dir", - "num_iterations", - "model_id", - "pose_estimation_method", - "task_mode", - ) + # Constants used by default as in kpms + DEFAULT_NUM_ITERS = 50 + FILTER_SIZE = 9 + MIN_DURATION = 3 + MIN_FREQUENCY = 0.005 + GRID_SAMPLES = 4 * 6 # minimum rows * cols kpms_root = moseq_train.get_kpms_root_data_dir() kpms_processed = moseq_train.get_kpms_processed_data_dir() @@ -267,10 +298,10 @@ def make(self, key): ) keypointset_dir = find_full_path(kpms_root, keypointset_dir) - inference_output_dir = os.path.join(model_dir, inference_output_dir) + inference_output_dir = Path(model_dir) / inference_output_dir - if not os.path.exists(inference_output_dir): - os.makedirs(model_dir / inference_output_dir) + if not inference_output_dir.exists(): + inference_output_dir.mkdir(parents=True, exist_ok=True) pca_path = model_dir.parent / "pca.p" if pca_path: @@ -294,15 +325,8 @@ def make(self, key): coordinates, confidences, _ = load_keypoints( filepath_pattern=keypointset_dir, format=pose_estimation_method ) - else: - raise NotImplementedError( - "The currently supported format method is `deeplabcut`. If you require \ - support for another format method, please reach out to us at `support@datajoint.com`." 
- ) - kpms_dj_config = load_kpms_dj_config( - model_dir.parent.as_posix(), check_if_valid=True, build_indexes=True - ) + kpms_dj_config = kpms_reader.load_kpms_dj_config(model_dir.parent) if kpms_dj_config: data, metadata = format_data(coordinates, confidences, **kpms_dj_config) @@ -322,8 +346,8 @@ def make(self, key): model_name=Path(model_dir).name, results_path=(inference_output_dir / "results.h5").as_posix(), return_model=False, - num_iters=num_iterations - or 50, # default internal value in the keypoint-moseq function + num_iters=num_iterations or DEFAULT_NUM_ITERS, + overwrite=True, **kpms_dj_config, ) end_time = datetime.now(timezone.utc) @@ -363,50 +387,43 @@ def make(self, key): ) else: - from keypoint_moseq import ( - filter_centroids_headings, - get_syllable_instances, - load_results, - sample_instances, - ) # load results results = load_results( - project_dir=Path(inference_output_dir).parent, - model_name=Path(inference_output_dir).parts[-1], + project_dir=inference_output_dir.parent, + model_name=inference_output_dir.parts[-1], ) - # extract syllables from results - syllables = {k: v["syllable"] for k, v in results.items()} + # extract syllables from results + syllables = {k: v["syllable"] for k, v in results.items()} - # extract and smooth centroids and headings - centroids = {k: v["centroid"] for k, v in results.items()} - headings = {k: v["heading"] for k, v in results.items()} - - filter_size = 9 # default value - centroids, headings = filter_centroids_headings( - centroids, headings, filter_size=filter_size - ) + # extract and smooth centroids and headings + centroids = {k: v["centroid"] for k, v in results.items()} + headings = {k: v["heading"] for k, v in results.items()} - # extract sample instances for each syllable - syllable_instances = get_syllable_instances( - syllables, min_duration=3, min_frequency=0.005 - ) - # Map each syllable to a list of its sampled events. 
- sampled_instances = sample_instances( - syllable_instances=syllable_instances, - num_samples=4 * 6, # minimum rows * cols - coordinates=coordinates, - centroids=centroids, - headings=headings, - ) + centroids, headings = filter_centroids_headings( + centroids, headings, filter_size=FILTER_SIZE + ) - duration_seconds = None + # extract sample instances for each syllable + syllable_instances = get_syllable_instances( + syllables, min_duration=MIN_DURATION, min_frequency=MIN_FREQUENCY + ) + # Map each syllable to a list of its sampled events. + sampled_instances = sample_instances( + syllable_instances=syllable_instances, + num_samples=GRID_SAMPLES, + coordinates=coordinates, + centroids=centroids, + headings=headings, + ) - self.insert1({**key, "inference_duration": duration_seconds}) + duration_seconds = None + # Prepare motion sequence data + motion_sequence_data = [] for result_idx, result in results.items(): - self.MotionSequence.insert1( + motion_sequence_data.append( { **key, "video_name": result_idx, @@ -414,10 +431,49 @@ def make(self, key): "latent_state": result["latent_state"], "centroid": result["centroid"], "heading": result["heading"], + "motion_sequence_file": ( + inference_output_dir / "results_as_csv" / f"{result_idx}.csv" + ).as_posix(), } ) + # Prepare grid movie data + grid_movie_data = [] for syllable, sampled_instance in sampled_instances.items(): - self.GridMoviesSampledInstances.insert1( + grid_movie_data.append( {**key, "syllable": syllable, "instances": sampled_instance} ) + + return ( + duration_seconds, + motion_sequence_data, + grid_movie_data, + inference_output_dir, + ) + + def make_insert( + self, + key, + duration_seconds, + motion_sequence_data, + grid_movie_data, + inference_output_dir, + ): + """ + Insert inference results into the database. 
+ """ + self.insert1( + { + **key, + "inference_duration": duration_seconds, + "syllable_segmentation_file": ( + inference_output_dir / "results.h5" + ).as_posix(), + } + ) + + for motion_record in motion_sequence_data: + self.MotionSequence.insert1(motion_record) + + for grid_record in grid_movie_data: + self.GridMoviesSampledInstances.insert1(grid_record) diff --git a/element_moseq/moseq_report.py b/element_moseq/moseq_report.py new file mode 100644 index 00000000..4b346b6c --- /dev/null +++ b/element_moseq/moseq_report.py @@ -0,0 +1,317 @@ +import importlib +import inspect +import os +import pathlib +import tempfile +from pathlib import Path + +import datajoint as dj +import matplotlib.pyplot as plt +import numpy as np +from element_interface.utils import find_full_path + +from . import moseq_infer, moseq_train +from .plotting import viz_utils +from .readers import kpms_reader + +schema = dj.schema() +_linking_module = None +logger = dj.logger + + +def activate( + report_schema_name: str, + *, + create_schema: bool = True, + create_tables: bool = True, + linking_module: str = None, +): + """Activate this schema. + + Args: + report_schema_name (str): Schema name on the database server to activate the `moseq_infer` schema. + create_schema (bool): When True (default), create schema in the database if it + does not yet exist. + create_tables (bool): When True (default), create schema tables in the database + if they do not yet exist. + linking_module (str): A module (or name) containing the required dependencies. 
+ """ + + if isinstance(linking_module, str): + linking_module = importlib.import_module(linking_module) + assert inspect.ismodule( + linking_module + ), "The argument 'dependency' must be a module's name or a module" + assert hasattr( + linking_module, "get_kpms_root_data_dir" + ), "The linking module must specify a lookup function for a root data directory" + + global _linking_module + _linking_module = linking_module + + # activate + schema.activate( + report_schema_name, + create_schema=create_schema, + create_tables=create_tables, + add_objects=_linking_module.__dict__, + ) + + +# ----------------------------- Table declarations ---------------------- + + +@schema +class PreProcessingReport(dj.Imported): + """Store the outlier keypoints plots that are generated in outbox by `moseq_train.PreProcessing`""" + + definition = """ + -> moseq_train.PreProcessing + video_id: int # ID of the matching video file + --- + recording_name: varchar(255) # Name of the recording + outlier_plot: attach # A plot of the outlier keypoints + """ + + def make(self, key): + # Resolve project dir + project_rel = (moseq_train.PCATask & key).fetch1("kpms_project_output_dir") + kpms_project_output_dir = ( + Path(moseq_train.get_kpms_processed_data_dir()) / project_rel + ) + + # Fetch video table info + video_paths, video_ids = (moseq_train.KeypointSet.VideoFile & key).fetch( + "video_path", "video_id" + ) + + # Get recording names from PreProcessing dict keys + coords = (moseq_train.PreProcessing & key).fetch1("coordinates") + recording_names = list(coords.keys()) + + # Build mapping recording_name -> video_id + rec2vid = viz_utils.build_recording_to_video_id( + recording_names=recording_names, + video_paths=list(video_paths), + video_ids=list(video_ids), + ) + + # Insert one row per recording that matched a video_id + for rec in recording_names: + vid = rec2vid.get(rec) + if vid is None: + dj.logger.warning( + f"[PreProcessingReport] No video_id match for recording '{rec}'. Skipping." 
+ ) + continue + + plot_path = ( + kpms_project_output_dir + / "quality_assurance" + / "plots" + / "keypoint_distance_outliers" + / f"{rec}.png" + ) + + if not plot_path.exists(): + dj.logger.warning( + f"[PreProcessingReport] Outlier plot not found at {plot_path}. Skipping." + ) + continue + + self.insert1( + { + **key, + "video_id": int(vid), + "recording_name": rec, + "outlier_plot": plot_path.as_posix(), + } + ) + + +@schema +class PCAReport(dj.Computed): + """ + Plots the principal components (PCs) from a PCAFit. + """ + + definition = """ + -> moseq_train.LatentDimension + --- + scree_plot: attach # A cumulative scree plot. + pcs_plot: attach # A visualization of each Principal Component (PC). + """ + + def make(self, key): + # Generate and store plots for the user to choose the latent dimensions in the next step + from keypoint_moseq import load_pca + + kpms_project_output_dir = (moseq_train.PCATask & key).fetch1( + "kpms_project_output_dir" + ) + kpms_project_output_dir = ( + moseq_train.get_kpms_processed_data_dir() / kpms_project_output_dir + ) + kpms_dj_config = kpms_reader.load_kpms_dj_config( + project_dir=kpms_project_output_dir + ) + + pca = load_pca(kpms_project_output_dir.as_posix()) + + # Modified version of plot_scree from keypoint_moseq + scree_fig = plt.figure() + num_pcs = len(pca.components_) + plt.plot(np.arange(num_pcs) + 1, np.cumsum(pca.explained_variance_ratio_)) + plt.xlabel("PCs") + plt.ylabel("Explained variance") + plt.gcf().set_size_inches((2.5, 2)) + plt.grid() + plt.tight_layout() + fname = f"{key['kpset_id']}_{key['bodyparts_id']}" + + # Modified version ofplot_pcs from keypoint_moseq to visualize components of PCs + pcs_fig = viz_utils.plot_pcs( + pca, + **kpms_dj_config, + interactive=False, + project_dir=kpms_project_output_dir, + ) + + tmpdir = tempfile.TemporaryDirectory() + + # plot variance summary + scree_path = pathlib.Path(tmpdir.name) / f"{fname}_scree_plot.png" + scree_fig.savefig(scree_path) + + # plot pcs + pcs_path 
= pathlib.Path(tmpdir.name) / f"{fname}_pcs_plot.png" + pcs_fig.savefig(pcs_path) + + # insert into table + self.insert1({**key, "scree_plot": scree_path, "pcs_plot": pcs_path}) + tmpdir.cleanup() + + +@schema +class PreFitReport(dj.Imported): + definition = """ + -> moseq_train.PreFit + --- + fitting_progress_pdf: attach # fitting_progress.pdf + """ + + def make(self, key): + prefit_model_name = (moseq_train.PreFit & key).fetch1("model_name") + prefit_model_dir = find_full_path( + moseq_train.get_kpms_processed_data_dir(), prefit_model_name + ) + prefit_output_dir = Path(prefit_model_dir) / "fitting_progress.pdf" + if prefit_output_dir.exists(): + self.insert1({**key, "fitting_progress_pdf": prefit_output_dir}) + else: + raise FileNotFoundError( + f"PreFit fitting_progress.pdf not found at {prefit_output_dir}" + ) + + +@schema +class FullFitReport(dj.Imported): + definition = """ + -> moseq_train.FullFit + --- + fitting_progress_pdf: attach # fitting_progress.pdf + """ + + def make(self, key): + fullfit_model_name = (moseq_train.FullFit & key).fetch1("model_name") + fullfit_model_dir = find_full_path( + moseq_train.get_kpms_processed_data_dir(), fullfit_model_name + ) + fullfit_output_file = Path(fullfit_model_dir) / "fitting_progress.pdf" + if fullfit_output_file.exists(): + self.insert1({**key, "fitting_progress_pdf": fullfit_output_file}) + else: + raise FileNotFoundError( + f"FullFit fitting_progress.pdf not found at {fullfit_output_file}" + ) + + +@schema +class InferenceReport(dj.Imported): + definition = """ + -> moseq_infer.Inference + --- + syllable_frequencies: attach + similarity_dendrogram_png: attach + similarity_dendrogram_pdf: attach + all_trajectories_gif: attach + all_trajectories_pdf: attach + """ + + class Trajectory(dj.Part): + definition = """ + -> master + syllable_id: int + --- + plot_gif: attach + plot_pdf: attach + grid_movie: attach + """ + + def make(self, key): + import imageio + + task_info = (moseq_infer.InferenceTask & key).fetch1() 
+ model = (moseq_infer.Model & {"model_id": task_info["model_id"]}).fetch1() + + model_dir = find_full_path( + moseq_train.get_kpms_processed_data_dir(), model["model_dir"] + ) + output_dir = Path(model_dir) / task_info["inference_output_dir"] + + # Insert per-inference entry + self.insert1( + { + **key, + "syllable_frequencies": output_dir / "syllable_frequencies.png", + "similarity_dendrogram_png": output_dir / "similarity_dendrogram.png", + "similarity_dendrogram_pdf": output_dir / "similarity_dendrogram.pdf", + "all_trajectories_gif": output_dir + / "trajectory_plots" + / "all_trajectories.gif", + "all_trajectories_pdf": output_dir + / "trajectory_plots" + / "all_trajectories.pdf", + } + ) + + # Insert per-syllable visuals + for syllable in (moseq_infer.Inference.GridMoviesSampledInstances & key).fetch( + "syllable" + ): + video_mp4_path = output_dir / "grid_movies" / f"syllable{syllable}.mp4" + video_mp4_to_gif_path = ( + output_dir / "grid_movies" / f"syllable{syllable}_grid_movie.gif" + ) + reader = imageio.get_reader(video_mp4_path) + fps = reader.get_meta_data()["fps"] + writer = imageio.get_writer(video_mp4_to_gif_path, fps=fps, loop=0) + + for frame in reader: + writer.append_data(frame) + + writer.close() + + self.Trajectory.insert1( + { + **key, + "syllable_id": syllable, + "plot_gif": output_dir + / "trajectory_plots" + / f"syllable{syllable}.gif", + "plot_pdf": output_dir + / "trajectory_plots" + / f"syllable{syllable}.pdf", + "grid_movie": video_mp4_to_gif_path, + } + ) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index f905dbab..192a7e66 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -18,9 +18,7 @@ from .readers import kpms_reader schema = dj.schema() - _linking_module = None - logger = dj.logger @@ -35,24 +33,21 @@ def activate( Args: train_schema_name (str): A string containing the name of the `moseq_train` schema. 
- infer_schema_name (str): A string containing the name of the `moseq_infer` schema. - create_schema (bool): If True (default), schema will be created in the database. + create_schema (bool): If True (default), schema will be created in the database. create_tables (bool): If True (default), tables related to the schema will be created in the database. linking_module (str): A string containing the module name or module containing the required dependencies to activate the schema. - Dependencies: Functions: - get_kpms_root_data_dir(): Returns absolute path for root data director(y/ies) - with all behavioral recordings, as (list of) string(s). + get_kpms_root_data_dir(): Returns absolute path for root data directory/ies + with all behavioral recordings, as (list of) string(s). get_kpms_processed_data_dir(): Optional. Returns absolute path for processed - data. - + data. """ if isinstance(linking_module, str): linking_module = importlib.import_module(linking_module) assert inspect.ismodule( linking_module - ), "The argument 'dependency' must be a module's name or a module" + ), "The argument 'dependency' must be a module's name or a module object" assert hasattr( linking_module, "get_kpms_root_data_dir" @@ -74,13 +69,14 @@ def activate( def get_kpms_root_data_dir() -> list: - """Pulls relevant func from parent namespace to specify root data dir(s). + """Fetches absolute data path to kpms data directories. - It is recommended that all paths in DataJoint Elements stored as relative - paths, with respect to some user-configured "root" director(y/ies). The - root(s) may vary between data modalities and user machines. Returns a full path - string or list of strings for possible root data directories. + The absolute path here is used as a reference for all downstream relative paths used in DataJoint. + + Returns: + A list of the absolute path(s) to kpms data directories. 
""" + root_directories = _linking_module.get_kpms_root_data_dir() if isinstance(root_directories, (str, Path)): root_directories = [root_directories] @@ -95,11 +91,10 @@ def get_kpms_root_data_dir() -> list: def get_kpms_processed_data_dir() -> Optional[str]: - """Pulls relevant func from parent namespace. Defaults to KPMS's project /videos/. + """Retrieve the root directory for all processed data. - Method in parent namespace should provide a string to a directory where KPMS output - files will be stored. If unspecified, output files will be stored in the - session directory 'videos' folder, per Keypoint-MoSeq default. + Returns: + A string for the full path to the root directory for processed data. """ if hasattr(_linking_module, "get_kpms_processed_data_dir"): return _linking_module.get_kpms_processed_data_dir() @@ -112,7 +107,7 @@ def get_kpms_processed_data_dir() -> Optional[str]: @schema class PoseEstimationMethod(dj.Lookup): - """Pose estimation methods supported by the keypoint loader of `keypoint-moseq` package. + """Name of the pose estimation method supported by the keypoint loader of `keypoint-moseq` package. Attributes: pose_estimation_method (str): Supported pose estimation method (deeplabcut, sleap, anipose, sleap-anipose, nwb, facemap) @@ -144,13 +139,13 @@ class KeypointSet(dj.Manual): kpset_id (int) : Unique ID for each keypoint set. PoseEstimationMethod (foreign key) : Unique format method used to obtain the keypoints data. kpset_dir (str) : Path where the keypoint files are located together with the pose estimation `config` file, relative to root data directory. - kpset_desc (str) : Optional. User-entered description. + kpset_desc (str) : Optional. User-entered description. 
""" definition = """ kpset_id : int # Unique ID for each keypoint set --- - -> PoseEstimationMethod # Unique format method used to obtain the keypoints data + -> PoseEstimationMethod # Unique format method used to obtain the keypoints data kpset_dir : varchar(255) # Path where the keypoint files are located together with the pose estimation `config` file, relative to root data directory kpset_desc='' : varchar(1000) # Optional. User-entered description """ @@ -166,9 +161,9 @@ class VideoFile(dj.Part): definition = """ -> master - video_id : int # Unique ID for each video corresponding to each keypoint data file, relative to root data directory + video_id : int # Unique ID for each video corresponding to each keypoint data file, relative to root data directory --- - video_path : varchar(1000) # Filepath of each video from which the keypoints are derived, relative to root data directory + video_path : varchar(1000) # Filepath of each video from which the keypoints are derived, relative to root data directory """ @@ -182,7 +177,7 @@ class Bodyparts(dj.Manual): anterior_bodyparts (blob) : List of strings of anterior bodyparts posterior_bodyparts (blob) : List of strings of posterior bodyparts use_bodyparts (blob) : List of strings of bodyparts to be used - bodyparts_desc(varchar) : Optional. User-entered description. + bodyparts_desc (varchar) : Optional. User-entered description. """ definition = """ @@ -199,69 +194,59 @@ class Bodyparts(dj.Manual): @schema class PCATask(dj.Manual): """ - Staging table to define the PCA task and its output directory. + Define the Principal Component Analysis (PCA) task for dimensionality reduction of keypoint data. Attributes: Bodyparts (foreign key) : Unique ID for each `Bodyparts` key - kpms_project_output_dir (str) : Keypoint-MoSeq project output directory, relative to root data directory + outlier_scale_factor (int) : Scale factor for outlier detection in keypoint data (default: 6) + kpms_project_output_dir (str) : Optional. 
Keypoint-MoSeq project output directory, relative to root data directory + task_mode (enum) : 'load' to load existing results, 'trigger' to compute new PCA """ definition = """ - -> Bodyparts # Unique ID for each `Bodyparts` key + -> Bodyparts # Unique ID for each `Bodyparts` key --- - kpms_project_output_dir='' : varchar(255) # Keypoint-MoSeq project output directory, relative to root data directory - task_mode='load' :enum('load','trigger') # Trigger or load the task + outlier_scale_factor=6 : int # Scale factor for outlier detection in keypoint data (default: 6) + kpms_project_output_dir='' : varchar(255) # Optional. Keypoint-MoSeq project output directory, relative to root data directory + task_mode='load' :enum('load','trigger') # 'load' to load existing results, 'trigger' to compute new PCA """ @schema -class PCAPrep(dj.Imported): +class PreProcessing(dj.Computed): """ - Table to set up the Keypoint-MoSeq project output directory (`kpms_project_output_dir`) , creating the default `config.yml` and updating it in a new `kpms_dj_config.yml`. + Preprocess keypoint data by cleaning outliers and setting up the Keypoint-MoSeq project configuration. Attributes: PCATask (foreign key) : Unique ID for each `PCATask` key. - coordinates (longblob) : Dictionary mapping filenames to keypoint coordinates as ndarrays of shape (n_frames, n_bodyparts, 2[or 3]). - confidences (longblob) : Dictionary mapping filenames to `likelihood` scores as ndarrays of shape (n_frames, n_bodyparts). + coordinates (longblob) : Dictionary mapping filenames to cleaned keypoint coordinates as ndarrays of shape (n_frames, n_bodyparts, 2[or 3]). + confidences (longblob) : Dictionary mapping filenames to updated likelihood scores as ndarrays of shape (n_frames, n_bodyparts). formatted_bodyparts (longblob) : List of bodypart names. The order of the names matches the order of the bodyparts in `coordinates` and `confidences`. 
- average_frame_rate (float) : Average frame rate of the videos for model training. - frame_rates (longblob) : List of the frame rates of the videos for model training. + average_frame_rate (float) : Average frame rate of the videos for model training (used for kappa calculation). """ definition = """ -> PCATask # Unique ID for each `PCATask` key --- coordinates : longblob # Dictionary mapping filenames to keypoint coordinates as ndarrays of shape (n_frames, n_bodyparts, 2[or 3]) - confidences : longblob # Dictionary mapping filenames to `likelihood` scores as ndarrays of shape (n_frames, n_bodyparts) + confidences : longblob # Dictionary mapping filenames to likelihood scores as ndarrays of shape (n_frames, n_bodyparts) formatted_bodyparts : longblob # List of bodypart names. The order of the names matches the order of the bodyparts in `coordinates` and `confidences`. - average_frame_rate : float # Average frame rate of the videos for model training - frame_rates : longblob # List of the frame rates of the videos for model training + average_frame_rate : float # Average frame rate of the videos for model training (used for kappa calculation). """ - def make(self, key): + class Video(dj.Part): + definition = """ + -> master + video_name: varchar(255) + --- + video_duration : int # Duration of each video in minutes + frame_rate : float # Frame rate of the video in frames per second (Hz) """ - Make function to: - 1. Generate and update the `kpms_dj_config.yml` with both the videoset directory and the bodyparts. - 2. Create the keypoint coordinates and confidences scores to format the data for the PCA fitting. - - Args: - key (dict): Primary key from the `PCATask` table. - Raises: - NotImplementedError: `pose_estimation_method` is only supported for `deeplabcut`. - - High-Level Logic: - 1. Fetches the bodyparts, format method, and the directories for the Keypoint-MoSeq project output, the keypoint set, and the video set. - 2. 
Set variables for each of the full path of the mentioned directories. - 3. Find the first existing pose estimation config file in the `kpset_dir` directory, if not found, raise an error. - 4. Check that the pose_estimation_method is `deeplabcut` and set up the project output directory with the default `config.yml`. - 5. Create the `kpms_project_output_dir` (if it does not exist), and generates the kpms default `config.yml` with the default values from the pose estimation config. - 6. Create a copy of the kpms `config.yml` named `kpms_dj_config.yml` that will be updated with both the `video_dir` and bodyparts - 7. Load keypoint data from the keypoint files found in the `kpset_dir` that will serve as the training set. - 8. As a result of the keypoint loading, the coordinates and confidences scores are generated and will be used to format the data for modeling. - 9. Calculate the average frame rate and the frame rate list of the videoset from which the keypoint set is derived. This two attributes can be used to calculate the kappa value. - 10. Insert the results of this `make` function into the table. + def make_fetch(self, key): + """ + Fetch required data for preprocessing from database tables. 
""" anterior_bodyparts, posterior_bodyparts, use_bodyparts = ( @@ -279,12 +264,78 @@ def make(self, key): "video_path", "video_id" ) - kpms_project_output_dir, task_mode = (PCATask & key).fetch1( - "kpms_project_output_dir", "task_mode" + kpms_project_output_dir, task_mode, outlier_scale_factor = ( + PCATask & key + ).fetch1("kpms_project_output_dir", "task_mode", "outlier_scale_factor") + + return ( + anterior_bodyparts, + posterior_bodyparts, + use_bodyparts, + pose_estimation_method, + kpset_dir, + video_paths, + video_ids, + kpms_project_output_dir, + task_mode, + outlier_scale_factor, + ) + + def make_compute( + self, + key, + anterior_bodyparts, + posterior_bodyparts, + use_bodyparts, + pose_estimation_method, + kpset_dir, + video_paths, + video_ids, + kpms_project_output_dir, + task_mode, + outlier_scale_factor, + ): + """ + Compute preprocessing steps including outlier removal and video metadata extraction. + + Args: + key (dict): Primary key from the `PCATask` table. + anterior_bodyparts (list): List of anterior bodyparts. + posterior_bodyparts (list): List of posterior bodyparts. + use_bodyparts (list): List of bodyparts to use. + pose_estimation_method (str): Pose estimation method (e.g., 'deeplabcut'). + kpset_dir (str): Keypoint set directory path. + video_paths (list): List of video file paths. + video_ids (list): List of video IDs. + kpms_project_output_dir (str): Project output directory path. + task_mode (str): Task mode ('load' or 'trigger'). + outlier_scale_factor (int): Scale factor for outlier detection. + + Returns: + tuple: Processed data including cleaned coordinates, confidences, and video metadata. + + Raises: + NotImplementedError: Only `deeplabcut` pose estimation method is supported. + FileNotFoundError: No DLC config file found in `kpset_dir`. + + High-Level Logic: + 1. Find the first existing pose estimation config file in the `kpset_dir` directory, if not found, raise an error. + 2. 
Check that the pose_estimation_method is `deeplabcut` and set up the project output directory with the default `config.yml`. + 3. Create the `kpms_project_output_dir` (if it does not exist), and generates the kpms default `config.yml` with the default values from the pose estimation config. + 4. Create a copy of the kpms `config.yml` named `kpms_dj_config.yml` that will be updated with both the `video_dir` and bodyparts + 5. Load keypoint data from the keypoint files found in the `kpset_dir` that will serve as the training set. + 6. Detect and remove outlier keypoints using medoid distance analysis, then interpolate missing values. + 7. Calculate the average frame rate and the frame rate list of the videoset from which the keypoint set is derived. These two attributes can be used to calculate the kappa value. + """ + from keypoint_moseq import ( + find_medoid_distance_outliers, + interpolate_keypoints, + load_keypoints, + plot_medoid_distance_outliers, ) if task_mode == "trigger": - from keypoint_moseq import load_config, setup_project + from keypoint_moseq import setup_project try: kpms_project_output_dir = find_full_path( @@ -293,7 +344,7 @@ def make(self, key): except FileNotFoundError: kpms_project_output_dir = ( - get_kpms_processed_data_dir() / kpms_project_output_dir + Path(get_kpms_processed_data_dir()) / kpms_project_output_dir ) kpset_dir = find_full_path(get_kpms_root_data_dir(), kpset_dir) @@ -302,35 +353,25 @@ def make(self, key): ) if pose_estimation_method == "deeplabcut": + from .readers.kpms_reader import _base_config_path + + cfg_path = _base_config_path(kpset_dir) + cfg = Path(cfg_path) + if not cfg.exists(): + raise FileNotFoundError( + f"No DLC config.(yml|yaml) found in {kpset_dir}" + ) setup_project( project_dir=kpms_project_output_dir.as_posix(), - deeplabcut_config=(kpset_dir / "config.yaml") - or (kpset_dir / "config.yml"), + deeplabcut_config=cfg.as_posix(), ) + else: raise NotImplementedError( "Currently, `deeplabcut` is the only pose 
estimation method supported by this Element. Please reach out at `support@datajoint.com` if you use another method." ) - kpms_config = load_config( - kpms_project_output_dir.as_posix(), - check_if_valid=True, - build_indexes=False, - ) - - kpms_dj_config_kwargs_dict = dict( - video_dir=videos_dir.as_posix(), - anterior_bodyparts=anterior_bodyparts, - posterior_bodyparts=posterior_bodyparts, - use_bodyparts=use_bodyparts, - ) - kpms_config.update(**kpms_dj_config_kwargs_dict) - kpms_reader.generate_kpms_dj_config( - kpms_project_output_dir.as_posix(), **kpms_config - ) else: - from keypoint_moseq import load_keypoints - kpms_project_output_dir = find_full_path( get_kpms_processed_data_dir(), kpms_project_output_dir ) @@ -339,60 +380,165 @@ def make(self, key): get_kpms_root_data_dir(), Path(video_paths[0]).parent ) - coordinates, confidences, formatted_bodyparts = load_keypoints( + raw_coordinates, raw_confidences, formatted_bodyparts = load_keypoints( filepath_pattern=kpset_dir, format=pose_estimation_method ) - frame_rate_list = [] - for fp, _ in zip(video_paths, video_ids): + video_metadata_list = [] + frame_rates = [] + for fp, video_id in zip(video_paths, video_ids): video_path = (find_full_path(get_kpms_root_data_dir(), fp)).as_posix() cap = cv2.VideoCapture(video_path) - frame_rate_list.append(int(cap.get(cv2.CAP_PROP_FPS))) + fps = float(cap.get(cv2.CAP_PROP_FPS)) + frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) cap.release() - average_frame_rate = int(np.mean(frame_rate_list)) + duration_minutes = (frame_count / fps) / 60.0 + frame_rates.append(fps) + + # Get video name for the Video part table + video_key = {"kpset_id": key["kpset_id"], "video_id": video_id} + if KeypointSet.VideoFile & video_key: + video_record = (KeypointSet.VideoFile & video_key).fetch1() + video_name = Path( + video_record["video_path"] + ).stem # Get filename without extension + video_metadata_list.append( + { + "video_id": video_id, + "video_name": video_name, + 
"video_duration": int(duration_minutes), + "frame_rate": fps, + } + ) + else: + logger.warning(f"Video record not found for video_id {video_id}") + + average_frame_rate = float(np.mean(frame_rates)) + + # Generate a copy of config.yml with the generated/updated info after it is known + kpms_reader.dj_generate_config( + project_dir=kpms_project_output_dir, + video_dir=str(videos_dir), + use_bodyparts=list(use_bodyparts), + anterior_bodyparts=list(anterior_bodyparts), + posterior_bodyparts=list(posterior_bodyparts), + outlier_scale_factor=float(outlier_scale_factor), + ) + kpms_reader.update_kpms_dj_config( + kpms_project_output_dir, + fps=average_frame_rate, + ) + + # Remove outlier keypoints + kpms_config = kpms_reader.load_kpms_dj_config(kpms_project_output_dir) + cleaned_coordinates = {} + cleaned_confidences = {} + + for recording_name in raw_coordinates: + raw_coords = raw_coordinates[recording_name].copy() + raw_conf = raw_confidences[recording_name].copy() + + # Find outliers using medoid distance analysis + outliers = find_medoid_distance_outliers( + raw_coords, outlier_scale_factor=outlier_scale_factor + ) + + # Interpolate keypoints to fix outliers + cleaned_coords = interpolate_keypoints(raw_coords, outliers["mask"]) + + # Update confidences for outlier points + cleaned_conf = np.where(outliers["mask"], 0, raw_conf) + + cleaned_coordinates[recording_name] = cleaned_coords + cleaned_confidences[recording_name] = cleaned_conf + + # Plot outliers + if formatted_bodyparts is not None: + try: + plot_medoid_distance_outliers( + project_dir=kpms_project_output_dir.as_posix(), + recording_name=recording_name, + original_coordinates=raw_coords, + interpolated_coordinates=cleaned_coords, + outlier_mask=outliers["mask"], + outlier_thresholds=outliers["thresholds"], + **kpms_config, + ) + + except Exception as e: + logger.warning( + f"Could not create outlier plot for {recording_name}: {e}" + ) + + return ( + cleaned_coordinates, + cleaned_confidences, + 
formatted_bodyparts, + average_frame_rate, + video_metadata_list, + ) + + def make_insert( + self, + key, + cleaned_coordinates, + cleaned_confidences, + formatted_bodyparts, + average_frame_rate, + video_metadata_list, + ): + """ + Insert processed data into the PreProcessing table and Video part table. + """ self.insert1( dict( **key, - coordinates=coordinates, - confidences=confidences, + coordinates=cleaned_coordinates, + confidences=cleaned_confidences, formatted_bodyparts=formatted_bodyparts, average_frame_rate=average_frame_rate, - frame_rates=frame_rate_list, ) ) + for video_metadata in video_metadata_list: + self.Video.insert1( + dict( + **key, + video_name=video_metadata["video_name"], + video_duration=video_metadata["video_duration"], + frame_rate=video_metadata["frame_rate"], + ) + ) + @schema class PCAFit(dj.Computed): - """Fit PCA model. + """Fit Principal Component Analysis (PCA) model for dimensionality reduction of keypoint data. Attributes: - PCAPrep (foreign key) : `PCAPrep` Key. + PreProcessing (foreign key) : `PreProcessing` Key. pca_fit_time (datetime) : datetime of the PCA fitting analysis. """ definition = """ - -> PCAPrep # `PCAPrep` Key + -> PreProcessing # `PreProcessing` Key --- pca_fit_time=NULL : datetime # datetime of the PCA fitting analysis """ def make(self, key): """ - Make function to format the keypoint data, fit the PCA model, and store it as a `pca.p` file in the Keypoint-MoSeq project output directory. + Format keypoint data and fit PCA model for dimensionality reduction. Args: - key (dict): `PCAPrep` Key - - Raises: + key (dict): `PreProcessing` Key High-Level Logic: - 1. Fetch the `kpms_project_output_dir` from the `PCATask` table and define its full path. - 2. Load the `kpms_dj_config` file that contains the updated `video_dir` and bodyparts, \ - and format the keypoint data with the coordinates and confidences scores to be used in the PCA fitting. - 3. Fit the PCA model and save it as `pca.p` file in the output directory. 
- 4.Insert the creation datetime as the `pca_fit_time` into the table. + 1. Fetch project output directory and load configuration. + 2. Format keypoint data with coordinates and confidences. + 3. Fit PCA model and save as `pca.p` file. + 4. Insert creation datetime into table. """ from keypoint_moseq import fit_pca, format_data, save_pca @@ -400,13 +546,13 @@ def make(self, key): "kpms_project_output_dir", "task_mode" ) kpms_project_output_dir = ( - get_kpms_processed_data_dir() / kpms_project_output_dir + Path(get_kpms_processed_data_dir()) / kpms_project_output_dir ) - kpms_default_config = kpms_reader.load_kpms_dj_config( - kpms_project_output_dir.as_posix(), check_if_valid=True, build_indexes=True + kpms_default_config = kpms_reader.load_kpms_dj_config(kpms_project_output_dir) + coordinates, confidences = (PreProcessing & key).fetch1( + "coordinates", "confidences" ) - coordinates, confidences = (PCAPrep & key).fetch1("coordinates", "confidences") data, _ = format_data( **kpms_default_config, coordinates=coordinates, confidences=confidences ) @@ -424,10 +570,7 @@ def make(self, key): @schema class LatentDimension(dj.Imported): """ - Determine the latent dimension as part of the autoregressive hyperparameters (`ar_hypparams`) for the model fitting. - The objective of the analysis is to inform the user about the number of principal components needed to explain a - 90% variance threshold. Subsequently, the decision on how many components to utilize for the model fitting is left - to the user. + Determine the optimal latent dimension for model fitting based on variance explained by PCA components. Attributes: PCAFit (foreign key) : `PCAFit` Key. @@ -446,55 +589,52 @@ class LatentDimension(dj.Imported): def make(self, key): """ - Make function to compute and store the latent dimension that explains a 90% variance threshold. + Compute and store the optimal latent dimension based on 90% variance threshold. Args: key (dict): `PCAFit` Key. 
Raises: + FileNotFoundError: No PCA model found in project directory. High-Level Logic: - 1. Fetches the Keypoint-MoSeq project output directory from the PCATask table and define the full path. - 2. Load the PCA model from file in this directory. - 2. Set a specified variance threshold to 90% and compute the cumulative sum of the explained variance ratio. - 3. Determine the number of components required to explain the specified variance. - 3.1 If the cumulative sum of the explained variance ratio is less than the specified variance threshold, \ - it sets the `latent_dimension` to the total number of components and `variance_percentage` to the cumulative sum of the explained variance ratio. - 3.2 If the cumulative sum of the explained variance ratio is greater than the specified variance threshold, \ - it sets the `latent_dimension` to the number of components that explain the specified variance and `variance_percentage` to the specified variance threshold. - 4. Insert the results of this `make` function into the table. + 1. Fetch project output directory and load PCA model. + 2. Calculate cumulative explained variance ratio. + 3. Determine number of components needed for 90% variance. + 4. Insert results into table. 
""" + + VARIANCE_THRESHOLD = 0.90 + from keypoint_moseq import load_pca kpms_project_output_dir = (PCATask & key).fetch1("kpms_project_output_dir") kpms_project_output_dir = ( - get_kpms_processed_data_dir() / kpms_project_output_dir + Path(get_kpms_processed_data_dir()) / kpms_project_output_dir ) pca_path = kpms_project_output_dir / "pca.p" - if pca_path: + if pca_path.exists(): pca = load_pca(kpms_project_output_dir.as_posix()) else: raise FileNotFoundError( f"No pca model (`pca.p`) found in the project directory {kpms_project_output_dir}" ) - variance_threshold = 0.90 - cs = np.cumsum( pca.explained_variance_ratio_ ) # explained_variance_ratio_ndarray of shape (n_components,) - if cs[-1] < variance_threshold: + if cs[-1] < VARIANCE_THRESHOLD: latent_dimension = len(cs) variance_percentage = cs[-1] * 100 latent_dim_desc = ( f"All components together only explain {cs[-1]*100}% of variance." ) else: - latent_dimension = (cs > variance_threshold).nonzero()[0].min() + 1 - variance_percentage = variance_threshold * 100 - latent_dim_desc = f">={variance_threshold*100}% of variance explained by {(cs>variance_threshold).nonzero()[0].min()+1} components." + latent_dimension = (cs > VARIANCE_THRESHOLD).nonzero()[0].min() + 1 + variance_percentage = VARIANCE_THRESHOLD * 100 + latent_dim_desc = f">={VARIANCE_THRESHOLD*100}% of variance explained by {(cs>VARIANCE_THRESHOLD).nonzero()[0].min()+1} components." self.insert1( dict( @@ -508,66 +648,63 @@ def make(self, key): @schema class PreFitTask(dj.Manual): - """Insert the parameters for the model (AR-HMM) pre-fitting. + """Define parameters for Stage 1: Auto-Regressive Hidden Markov Model (AR-HMM) pre-fitting. Attributes: PCAFit (foreign key) : `PCAFit` task. pre_latent_dim (int) : Latent dimension to use for the model pre-fitting. - pre_kappa (int) : Kappa value to use for the model pre-fitting. - pre_num_iterations (int) : Number of Gibbs sampling iterations to run in the model pre-fitting. 
+ pre_kappa (int) : Kappa value to use for the model pre-fitting (controls syllable duration). + pre_num_iterations (int) : Number of Gibbs sampling iterations to run in the model pre-fitting (typically 10-50). + model_name (varchar) : Name of the model to be loaded if `task_mode='load'` + task_mode (enum) : 'load': load computed analysis results, 'trigger': trigger computation pre_fit_desc(varchar) : User-defined description of the pre-fitting task. """ definition = """ - -> PCAFit # `PCAFit` Key - pre_latent_dim : int # Latent dimension to use for the model pre-fitting - pre_kappa : int # Kappa value to use for the model pre-fitting - pre_num_iterations : int # Number of Gibbs sampling iterations to run in the model pre-fitting + -> PCAFit # `PCAFit` Key + pre_latent_dim : int # Latent dimension to use for the model pre-fitting. + pre_kappa : int # Kappa value to use for the model pre-fitting (controls syllable duration). + pre_num_iterations : int # Number of Gibbs sampling iterations to run in the model pre-fitting (typically 10-50). --- - model_name : varchar(100) # Name of the model to be loaded if `task_mode='load'` - task_mode='load' :enum('trigger','load')# 'load': load computed analysis results, 'trigger': trigger computation - pre_fit_desc='' : varchar(1000) # User-defined description of the pre-fitting task + model_name='' : varchar(1000) # Name of the model to be loaded if `task_mode='load'` + task_mode='load' :enum('load','trigger') # 'load': load computed analysis results, 'trigger': trigger computation + pre_fit_desc='' : varchar(1000) # User-defined description of the pre-fitting task """ @schema class PreFit(dj.Computed): - """Fit AR-HMM model. + """Fit Auto-Regressive Hidden Markov Model (AR-HMM) for initial behavioral syllable discovery. Attributes: PreFitTask (foreign key) : `PreFitTask` Key. - model_name (varchar) : Name of the model as "kpms_project_output_dir/model_name". + model_name (varchar) : Name of the model as "model_name". 
pre_fit_duration (float) : Time duration (seconds) of the model fitting computation. """ definition = """ - -> PreFitTask # `PreFitTask` Key + -> PreFitTask # `PreFitTask` Key --- - model_name='' : varchar(100) # Name of the model as "kpms_project_output_dir/model_name" - pre_fit_duration=NULL : float # Time duration (seconds) of the model fitting computation + model_name='' : varchar(1000) # Name of the model as "kpms_project_output_dir/model_name" + pre_fit_duration=NULL : float # Time duration (seconds) of the model fitting computation """ def make(self, key): """ - Make function to fit the AR-HMM model using the latent trajectory defined by `model['states']['x']. + Fit AR-HMM model for initial behavioral syllable discovery. Args: - key (dict) : dictionary with the `PreFitTask` Key. + key (dict): Dictionary with the `PreFitTask` Key. Raises: + FileNotFoundError: No PCA model found in project directory. - High-level Logic: - 1. Fetch the `kpms_project_output_dir` and define the full path. - 2. Fetch the model parameters from the `PreFitTask` table. - 3. Update the `dj_config.yml` with the latent dimension and kappa for the AR-HMM fitting. - 4. Load the pca model from file in the `kpms_project_output_dir`. - 5. Fetch `coordinates` and `confidences` scores to format the data for the model initialization. \ - # Data - contains the data for model fitting. \ - # Metadata - contains the recordings and start/end frames for the data. - 6. Initialize the model that create a `model` dict containing states, parameters, hyperparameters, noise prior, and random seed. - 7. Update the model dict with the selected kappa for the AR-HMM fitting. - 8. Fit the AR-HMM model using the `pre_num_iterations` and create a subdirectory in `kpms_project_output_dir` with the model's latest checkpoint file. - 9. Calculate the duration of the model fitting computation and insert it in the `PreFit` table. + High-Level Logic: + 1. Fetch project output directory and model parameters. + 2. 
Update configuration with latent dimension and kappa values. + 3. Load PCA model and format keypoint data. + 4. Initialize and fit AR-HMM model. + 5. Calculate fitting duration and insert results. """ from keypoint_moseq import ( fit_model, @@ -592,31 +729,43 @@ def make(self, key): "model_name", ) if task_mode == "trigger": - kpms_dj_config = kpms_reader.load_kpms_dj_config( - kpms_project_output_dir.as_posix(), - check_if_valid=True, - build_indexes=True, - ) + from keypoint_moseq import estimate_sigmasq_loc - kpms_dj_config.update( - dict(latent_dim=int(pre_latent_dim), kappa=float(pre_kappa)) - ) - kpms_reader.generate_kpms_dj_config( - kpms_project_output_dir.as_posix(), **kpms_dj_config + # Update the existing kpms_dj_config.yml with new latent_dim and kappa values + kpms_reader.update_kpms_dj_config( + kpms_project_output_dir, + latent_dim=int(pre_latent_dim), + kappa=float(pre_kappa), ) + # Load the updated config for use in model fitting + kpms_dj_config = kpms_reader.load_kpms_dj_config(kpms_project_output_dir) + pca_path = kpms_project_output_dir / "pca.p" - if pca_path: + if pca_path.exists(): pca = load_pca(kpms_project_output_dir.as_posix()) else: raise FileNotFoundError( f"No pca model (`pca.p`) found in the project directory {kpms_project_output_dir}" ) - coordinates, confidences = (PCAPrep & key).fetch1( + coordinates, confidences = (PreProcessing & key).fetch1( "coordinates", "confidences" ) - data, metadata = format_data(coordinates, confidences, **kpms_dj_config) + data, metadata = format_data( + coordinates=coordinates, confidences=confidences, **kpms_dj_config + ) + + kpms_reader.update_kpms_dj_config( + kpms_project_output_dir, + sigmasq_loc=estimate_sigmasq_loc( + data["Y"], data["mask"], filter_size=int(kpms_dj_config["fps"]) + ), + ) + + kpms_dj_config = kpms_reader.load_kpms_dj_config( + project_dir=kpms_project_output_dir + ) model = init_model(data=data, metadata=metadata, pca=pca, **kpms_dj_config) @@ -624,14 +773,19 @@ def make(self, 
key): model, kappa=float(pre_kappa), latent_dim=int(pre_latent_dim) ) + model_name_str = f"latent_dim_{int(pre_latent_dim)}_kappa_{float(pre_kappa)}_iters_{int(pre_num_iterations)}" + start_time = datetime.now(timezone.utc) model, model_name = fit_model( model=model, + model_name=model_name_str, data=data, metadata=metadata, project_dir=kpms_project_output_dir.as_posix(), ar_only=True, num_iters=pre_num_iterations, + generate_progress_plots=True, # saved to {project_dir}/{model_name}/plots/ + save_every_n_iters=25, ) end_time = datetime.now(timezone.utc) @@ -653,25 +807,25 @@ def make(self, key): @schema class FullFitTask(dj.Manual): - """Insert the parameters for the full (Keypoint-SLDS model) fitting. - The full model will generally require a lower value of kappa to yield the same target syllable durations. + """Define parameters for FullFit step of model fitting. Attributes: PCAFit (foreign key) : `PCAFit` Key. full_latent_dim (int) : Latent dimension to use for the model full fitting. - full_kappa (int) : Kappa value to use for the model full fitting. - full_num_iterations (int) : Number of Gibbs sampling iterations to run in the model full fitting. + full_kappa (int) : Kappa value to use for the model full fitting (typically lower than pre-fit kappa). + full_num_iterations (int) : Number of Gibbs sampling iterations to run in the model full fitting (typically 200-500). + model_name (varchar) : Name of the model to be loaded if `task_mode='load'` + task_mode (enum) : 'load': load computed analysis results, 'trigger': trigger computation full_fit_desc(varchar) : User-defined description of the model full fitting task. 
- """ definition = """ -> PCAFit # `PCAFit` Key full_latent_dim : int # Latent dimension to use for the model full fitting - full_kappa : int # Kappa value to use for the model full fitting - full_num_iterations : int # Number of Gibbs sampling iterations to run in the model full fitting + full_kappa : int # Kappa value to use for the model full fitting (typically lower than pre-fit kappa). + full_num_iterations : int # Number of Gibbs sampling iterations to run in the model full fitting (typically 200-500). --- - model_name : varchar(100) # Name of the model to be loaded if `task_mode='load'` + model_name='' : varchar(1000) # Name of the model to be loaded if `task_mode='load'` task_mode='load' :enum('load','trigger')# Trigger or load the task full_fit_desc='' : varchar(1000) # User-defined description of the model full fitting task """ @@ -679,7 +833,7 @@ class FullFitTask(dj.Manual): @schema class FullFit(dj.Computed): - """Fit the full (Keypoint-SLDS) model. + """Fit the complete Keypoint Switching Linear Dynamical System (Keypoint-SLDS) model. Attributes: FullFitTask (foreign key) : `FullFitTask` Key. @@ -688,37 +842,32 @@ class FullFit(dj.Computed): """ definition = """ - -> FullFitTask # `FullFitTask` Key + -> FullFitTask # `FullFitTask` Key --- - model_name : varchar(100) # Name of the model as "kpms_project_output_dir/model_name" - full_fit_duration=NULL : float # Time duration (seconds) of the full fitting computation + model_name='' : varchar(1000) # Name of the model as "kpms_project_output_dir/model_name" + full_fit_duration=NULL : float # Time duration (seconds) of the full fitting computation """ def make(self, key): """ - Make function to fit the full (keypoint-SLDS) model - - Args: - key (dict): dictionary with the `FullFitTask` Key. - - Raises: - - High-level Logic: - 1. Fetch the `kpms_project_output_dir` and define the full path. - 2. Fetch the model parameters from the `FullFitTask` table. - 2. 
Update the `dj_config.yml` with the selected latent dimension and kappa for the full-fitting. - 3. Initialize and fit the full model in a new `model_name` directory. - 4. Load the pca model from file in the `kpms_project_output_dir`. - 5. Fetch the `coordinates` and `confidences` scores to format the data for the model initialization. - 6. Initialize the model that create a `model` dict containing states, parameters, hyperparameters, noise prior, and random seed. - 7. Update the model dict with the selected kappa for the Keypoint-SLDS fitting. - 8. Fit the Keypoint-SLDS model using the `full_num_iterations` and create a subdirectory in `kpms_project_output_dir` with the model's latest checkpoint file. - 8. Reindex syllable labels by their frequency in the most recent model snapshot in the checkpoint file. \ - This function permutes the states and parameters of a saved checkpoint so that syllables are labeled \ - in order of frequency (i.e. so that 0 is the most frequent, 1 is the second most, and so on). - 8. Calculate the duration of the model fitting computation and insert it in the `PreFit` table. + Fit the complete Keypoint-SLDS model with spatial and temporal dynamics. + + Args: + key (dict): Dictionary with the `FullFitTask` Key. + + Raises: + FileNotFoundError: No PCA model found in project directory. + + High-Level Logic: + 1. Fetch project output directory and model parameters. + 2. Update configuration with latent dimension and kappa values. + 3. Load PCA model and format keypoint data. + 4. Initialize and fit Keypoint-SLDS model. + 5. Reindex syllable labels by frequency. + 6. Calculate fitting duration and insert results. 
""" from keypoint_moseq import ( + estimate_sigmasq_loc, fit_model, format_data, init_model, @@ -742,38 +891,52 @@ def make(self, key): "model_name", ) if task_mode == "trigger": - kpms_dj_config = kpms_reader.load_kpms_dj_config( - kpms_project_output_dir.as_posix(), - check_if_valid=True, - build_indexes=True, - ) - kpms_dj_config.update( - dict(latent_dim=int(full_latent_dim), kappa=float(full_kappa)) + kpms_reader.update_kpms_dj_config( + project_dir=kpms_project_output_dir, + latent_dim=int(full_latent_dim), + kappa=float(full_kappa), ) - kpms_reader.generate_kpms_dj_config( - kpms_project_output_dir.as_posix(), **kpms_dj_config + + kpms_dj_config = kpms_reader.load_kpms_dj_config( + project_dir=kpms_project_output_dir ) pca_path = kpms_project_output_dir / "pca.p" - if pca_path: + if pca_path.exists(): pca = load_pca(kpms_project_output_dir.as_posix()) else: raise FileNotFoundError( f"No pca model (`pca.p`) found in the project directory {kpms_project_output_dir}" ) - coordinates, confidences = (PCAPrep & key).fetch1( + coordinates, confidences = (PreProcessing & key).fetch1( "coordinates", "confidences" ) - data, metadata = format_data(coordinates, confidences, **kpms_dj_config) + data, metadata = format_data( + coordinates=coordinates, confidences=confidences, **kpms_dj_config + ) + kpms_reader.update_kpms_dj_config( + project_dir=kpms_project_output_dir, + sigmasq_loc=estimate_sigmasq_loc( + data["Y"], data["mask"], filter_size=int(kpms_dj_config["fps"]) + ), + ) + + kpms_dj_config = kpms_reader.load_kpms_dj_config( + project_dir=kpms_project_output_dir + ) + model = init_model(data=data, metadata=metadata, pca=pca, **kpms_dj_config) model = update_hypparams( model, kappa=float(full_kappa), latent_dim=int(full_latent_dim) ) + model_name_str = f"latent_dim_{int(full_latent_dim)}_kappa_{float(full_kappa)}_iters_{int(full_num_iterations)}" + start_time = datetime.now(timezone.utc) model, model_name = fit_model( model=model, + model_name=model_name_str, 
data=data, metadata=metadata, project_dir=kpms_project_output_dir.as_posix(), @@ -784,8 +947,10 @@ def make(self, key): duration_seconds = (end_time - start_time).total_seconds() reindex_syllables_in_checkpoint( - kpms_project_output_dir.as_posix(), Path(model_name).parts[-1] + project_dir=kpms_project_output_dir.as_posix(), + model_name=Path(model_name).parts[-1], ) + else: duration_seconds = None @@ -803,11 +968,17 @@ def make(self, key): @schema class SelectedFullFit(dj.Manual): - """Selected FullFit model for use in the inference pipeline""" + """Register selected FullFit models for use in the inference pipeline. + + Attributes: + FullFit (foreign key) : `FullFit` Key. + registered_model_name (varchar): User-friendly model name + registered_model_desc (varchar): Optional user-defined description + """ definition = """ -> FullFit --- - registered_model_name : varchar(64) # User-friendly model name + registered_model_name : varchar(1000) # User-friendly model name registered_model_desc='' : varchar(1000) # Optional user-defined description """ diff --git a/element_moseq/plotting/__init__.py b/element_moseq/plotting/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/element_moseq/plotting/viz_utils.py b/element_moseq/plotting/viz_utils.py new file mode 100644 index 00000000..14e35882 --- /dev/null +++ b/element_moseq/plotting/viz_utils.py @@ -0,0 +1,330 @@ +# ---- Modified version of the viz functions from the main branch of keypoint_moseq ---- + +import os +import re +from difflib import SequenceMatcher +from pathlib import Path +from textwrap import fill +from typing import Dict, List, Optional, Tuple + +import matplotlib.pyplot as plt +import numpy as np +from jax_moseq.models.keypoint_slds import center_embedding +from keypoint_moseq.util import get_distance_to_medoid, get_edges, plot_keypoint_traces +from keypoint_moseq.viz import plot_pcs_3D + +_DLC_SUFFIX_RE = re.compile( + r"(?:DLC_[A-Za-z0-9]+[A-Za-z]+(?:\d+)?(?:[A-Za-z]+)?" 
# scorer-ish token + r"(?:\w+)*)" # optional extra blobs + r"(?:shuffle\d+)?" # shuffleN + r"(?:_\d+)?$" # _iter +) + + +def _normalize_name(name: str) -> str: + """ + Normalize a recording/video string for matching: + - lowercase, strip whitespace + - drop extension if present + - remove common DLC suffix blob (e.g., '...DLC_resnet50_...shuffle1_500000') + - collapse separators to single spaces + """ + s = name.lower().strip() + s = Path(s).stem + s = _DLC_SUFFIX_RE.sub("", s) # strip DLC tail if present + s = re.sub(r"[\s._-]+", " ", s).strip() + return s + + +def build_recording_to_video_id( + recording_names: List[str], + video_paths: List[str], + video_ids: List[int], + fuzzy_threshold: float = 0.80, +) -> Dict[str, Optional[int]]: + """ + Returns: {recording_name -> video_id or None if no good match} + Strategy: exact normalized stem match; if none, substring; then fuzzy. + """ + # candidate stems from videos + stems: List[Tuple[str, int]] = [ + (_normalize_name(Path(p).name), vid) for p, vid in zip(video_paths, video_ids) + ] + + mapping: Dict[str, Optional[int]] = {} + + for rec in recording_names: + nrec = _normalize_name(rec) + + # 1) exact normalized match + exact = [vid for stem, vid in stems if stem == nrec] + if exact: + mapping[rec] = exact[0] + continue + + # 2) substring either way (choose longest stem to disambiguate) + subs = [(stem, vid) for stem, vid in stems if nrec in stem or stem in nrec] + if subs: + subs.sort(key=lambda x: len(x[0]), reverse=True) + mapping[rec] = subs[0][1] + continue + + # 3) fuzzy best match + best_vid, best_ratio = None, 0.0 + for stem, vid in stems: + r = SequenceMatcher(None, nrec, stem).ratio() + if r > best_ratio: + best_ratio, best_vid = r, vid + mapping[rec] = best_vid if best_ratio >= fuzzy_threshold else None + + return mapping + + +def plot_medoid_distance_outliers( + project_dir: str, + recording_name: str, + original_coordinates: np.ndarray, + interpolated_coordinates: np.ndarray, + outlier_mask, + 
outlier_thresholds,
+    bodyparts: list[str],
+    **kwargs,
+):
+    """Create and save a plot comparing distance-to-medoid for original vs. interpolated keypoints.
+
+    Generates a multi-panel plot showing the distance from each keypoint to the medoid
+    position for both original and interpolated coordinates. The plot includes threshold
+    lines and shaded regions for outlier frames. Saves the figure to the quality
+    assurance plots directory.
+
+    Parameters
+    ----------
+    project_dir: str
+        Path to the project directory where the plot will be saved.
+
+    recording_name: str
+        Name of the recording, used for the plot title and filename.
+
+    original_coordinates: ndarray of shape (n_frames, n_keypoints, keypoint_dim)
+        Original keypoint coordinates before interpolation.
+
+    interpolated_coordinates: ndarray of shape (n_frames, n_keypoints, keypoint_dim)
+        Keypoint coordinates after interpolation.
+
+    outlier_mask: ndarray of shape (n_frames, n_keypoints)
+        Boolean mask indicating outlier keypoints (True = outlier).
+
+    outlier_thresholds: ndarray of shape (n_keypoints,)
+        Distance thresholds for each keypoint above which points are considered outliers.
+
+    bodyparts: list of str
+        Names of bodyparts corresponding to each keypoint. Must have length equal to
+        n_keypoints.
+
+    **kwargs
+        Additional keyword arguments (ignored), usually overflow from **config().
+
+    Returns
+    -------
+    fig : matplotlib Figure
+        The created figure; also saved to 'quality_assurance/plots/keypoint_distance_outliers/{recording_name}.png'.
+ """ + + plot_path = os.path.join( + project_dir, + "quality_assurance", + "plots", + "keypoint_distance_outliers", + f"{recording_name}.png", + ) + os.makedirs(os.path.dirname(plot_path), exist_ok=True) + + original_distances = get_distance_to_medoid( + original_coordinates + ) # (n_frames, n_keypoints) + interpolated_distances = get_distance_to_medoid( + interpolated_coordinates + ) # (n_frames, n_keypoints) + + fig = plot_keypoint_traces( + traces=[original_distances, interpolated_distances], + plot_title=recording_name, + bodyparts=bodyparts, + line_labels=["Original", "Interpolated"], + thresholds=outlier_thresholds, + shading_mask=outlier_mask, + ) + + fig.savefig(plot_path, dpi=300) + + plt.close() + print(f"Saved keypoint distance outlier plot for {recording_name} to {plot_path}.") + return fig + + +def plot_pcs( + pca, + *, + use_bodyparts, + skeleton, + keypoint_colormap="autumn", + keypoint_colors=None, + savefig=True, + project_dir=None, + scale=1, + plot_n_pcs=10, + axis_size=(2, 1.5), + ncols=5, + node_size=30.0, + line_width=2.0, + interactive=True, + **kwargs, +): + """ + Visualize the components of a fitted PCA model. + + For each PC, a subplot shows the mean pose (semi-transparent) along with a + perturbation of the mean pose in the direction of the PC. + + Parameters + ---------- + pca : :py:func:`sklearn.decomposition.PCA` + Fitted PCA model + + use_bodyparts : list of str + List of bodyparts to that are used in the model; used to index bodypart + names in the skeleton. + + skeleton : list + List of edges that define the skeleton, where each edge is a pair of + bodypart names. + + keypoint_colormap : str + Name of a matplotlib colormap to use for coloring the keypoints. + + keypoint_colors : array-like, shape=(num_keypoints,3), default=None + Color for each keypoint. If None, `keypoint_colormap` is used. If the + dtype is int, the values are assumed to be in the range 0-255, + otherwise they are assumed to be in the range 0-1. 
+ + savefig : bool, True + Whether to save the figure to a file. If true, the figure is saved to + `{project_dir}/pcs-{xy/xz/yz}.pdf` (`xz` and `yz` are only included + for 3D data). + + project_dir : str, default=None + Path to the project directory. Required if `savefig` is True. + + scale : float, default=0.5 + Scale factor for the perturbation of the mean pose. + + plot_n_pcs : int, default=10 + Number of PCs to plot. + + axis_size : tuple of float, default=(2,1.5) + Size of each subplot in inches. + + ncols : int, default=5 + Number of columns in the figure. + + node_size : float, default=30.0 + Size of the keypoints in the figure. + + line_width: float, default=2.0 + Width of edges in skeleton + + interactive : bool, default=True + For 3D data, whether to generate an interactive 3D plot. + """ + k = len(use_bodyparts) + d = len(pca.mean_) // (k - 1) + + if keypoint_colors is None: + cmap = plt.cm.get_cmap(keypoint_colormap) + keypoint_colors = cmap(np.linspace(0, 1, k)) + + Gamma = np.array(center_embedding(k)) + edges = get_edges(use_bodyparts, skeleton) + plot_n_pcs = min(plot_n_pcs, pca.components_.shape[0]) + + magnitude = np.sqrt((pca.mean_**2).mean()) * scale + ymean = Gamma @ pca.mean_.reshape(k - 1, d) + ypcs = (pca.mean_ + magnitude * pca.components_).reshape(-1, k - 1, d) + ypcs = Gamma[np.newaxis] @ ypcs[:plot_n_pcs] + + if d == 2: + dims_list, names = [[0, 1]], ["xy"] + if d == 3: + dims_list, names = [[0, 1], [0, 2]], ["xy", "xz"] + + for dims, name in zip(dims_list, names): + nrows = int(np.ceil(plot_n_pcs / ncols)) + fig, axs = plt.subplots(nrows, ncols, sharex=True, sharey=True) + for i, ax in enumerate(axs.flat): + if i >= plot_n_pcs: + ax.axis("off") + continue + + for e in edges: + ax.plot( + *ymean[:, dims][e].T, + color=keypoint_colors[e[0]], + zorder=0, + alpha=0.25, + linewidth=line_width, + ) + ax.plot( + *ypcs[i][:, dims][e].T, + color="k", + zorder=2, + linewidth=line_width + 0.2, + ) + ax.plot( + *ypcs[i][:, dims][e].T, + 
color=keypoint_colors[e[0]], + zorder=3, + linewidth=line_width, + ) + + ax.scatter( + *ymean[:, dims].T, + c=keypoint_colors, + s=node_size, + zorder=1, + alpha=0.25, + linewidth=0, + ) + ax.scatter( + *ypcs[i][:, dims].T, + c=keypoint_colors, + s=node_size, + zorder=4, + edgecolor="k", + linewidth=0.2, + ) + + ax.set_title(f"PC {i+1}", fontsize=10) + ax.set_aspect("equal") + ax.axis("off") + + fig.set_size_inches((axis_size[0] * ncols, axis_size[1] * nrows)) + plt.tight_layout() + + if savefig: + assert project_dir is not None, fill( + "The `savefig` option requires a `project_dir`" + ) + plt.savefig(os.path.join(project_dir, f"pcs-{name}.pdf")) + plt.show() + + if interactive and d == 3: + plot_pcs_3D( + ymean, + ypcs, + edges, + keypoint_colormap, + project_dir if savefig else None, + node_size / 3, + line_width * 2, + ) + return fig diff --git a/element_moseq/readers/kpms_reader.py b/element_moseq/readers/kpms_reader.py index 4664fdcf..59b8bc4e 100644 --- a/element_moseq/readers/kpms_reader.py +++ b/element_moseq/readers/kpms_reader.py @@ -1,5 +1,7 @@ import logging import os +from pathlib import Path +from typing import Any, Dict, Union import jax.numpy as jnp import yaml @@ -7,181 +9,142 @@ logger = logging.getLogger("datajoint") -def generate_kpms_dj_config(output_dir, **kwargs): - """This function mirrors the behavior of the `generate_config` function from the `keypoint_moseq` - package. Nonetheless, it produces a duplicate of the initial configuration file, titled - `kpms_dj_config.yml`, in the output directory to maintain the integrity of the original file. - This replicated file accommodates any customized project settings, with default configurations - utilized unless specified differently via keyword arguments. +DJ_CONFIG = "kpms_dj_config.yml" +BASE_CONFIG = "config.yml" - Args: - output_dir (str): Directory containing the `kpms_dj_config.yml` that will be generated. - kwargs (dict): Custom project settings. 
- """ - def _build_yaml(sections, comments): - text_blocks = [] - for title, data in sections: - centered_title = f" {title} ".center(50, "=") - text_blocks.append(f"\n\n{'#'}{centered_title}{'#'}") - for key, value in data.items(): - text = yaml.dump({key: value}).strip("\n") - if key in comments: - text = f"\n{'#'} {comments[key]}\n{text}" - text_blocks.append(text) - return "\n".join(text_blocks) - - def _update_dict(new, original): - return {k: new[k] if k in new else v for k, v in original.items()} - - hypperams = _update_dict( - kwargs, - { - "error_estimator": {"slope": -0.5, "intercept": 0.25}, - "obs_hypparams": { - "sigmasq_0": 0.1, - "sigmasq_C": 0.1, - "nu_sigma": 1e5, - "nu_s": 5, - }, - "ar_hypparams": { - "latent_dim": 10, - "nlags": 3, - "S_0_scale": 0.01, - "K_0_scale": 10.0, - }, - "trans_hypparams": { - "num_states": 100, - "gamma": 1e3, - "alpha": 5.7, - "kappa": 1e6, - }, - "cen_hypparams": {"sigmasq_loc": 0.5}, - }, - ) - - hypperams = {k: _update_dict(kwargs, v) for k, v in hypperams.items()} - - anatomy = _update_dict( - kwargs, - { - "bodyparts": ["BODYPART1", "BODYPART2", "BODYPART3"], - "use_bodyparts": ["BODYPART1", "BODYPART2", "BODYPART3"], - "skeleton": [ - ["BODYPART1", "BODYPART2"], - ["BODYPART2", "BODYPART3"], - ], - "anterior_bodyparts": ["BODYPART1"], - "posterior_bodyparts": ["BODYPART3"], - }, - ) - - other = _update_dict( - kwargs, - { - "recording_name_suffix": "", - "verbose": False, - "conf_pseudocount": 1e-3, - "video_dir": "", - "keypoint_colormap": "autumn", - "whiten": True, - "fix_heading": False, - "seg_length": 10000, - }, - ) - - fitting = _update_dict( - kwargs, - { - "added_noise_level": 0.1, - "PCA_fitting_num_frames": 1000000, - "conf_threshold": 0.5, - # 'kappa_scan_target_duration': 12, - # 'kappa_scan_min': 1e2, - # 'kappa_scan_max': 1e12, - # 'num_arhmm_scan_iters': 50, - # 'num_arhmm_final_iters': 200, - # 'num_kpslds_scan_iters': 50, - # 'num_kpslds_final_iters': 500 - }, - ) - - comments = { - "verbose": 
"whether to print progress messages during fitting", - "keypoint_colormap": "colormap used for visualization; see `matplotlib.cm.get_cmap` for options", - "added_noise_level": "upper bound of uniform noise added to the data during initial AR-HMM fitting; this is used to regularize the model", - "PCA_fitting_num_frames": "number of frames used to fit the PCA model during initialization", - "video_dir": "directory with videos from which keypoints were derived (used for crowd movies)", - "recording_name_suffix": "suffix used to match videos to recording names; this can usually be left empty (see `util.find_matching_videos` for details)", - "bodyparts": "used to access columns in the keypoint data", - "skeleton": "used for visualization only", - "use_bodyparts": "determines the subset of bodyparts to use for modeling and the order in which they are represented", - "anterior_bodyparts": "used to initialize heading", - "posterior_bodyparts": "used to initialize heading", - "seg_length": "data are broken up into segments to parallelize fitting", - "trans_hypparams": "transition hyperparameters", - "ar_hypparams": "autoregressive hyperparameters", - "obs_hypparams": "keypoint observation hyperparameters", - "cen_hypparams": "centroid movement hyperparameters", - "error_estimator": "parameters to convert neural net likelihoods to error size priors", - "save_every_n_iters": "frequency for saving model snapshots during fitting; if 0 only final state is saved", - "kappa_scan_target_duration": "target median syllable duration (in frames) for choosing kappa", - "whiten": "whether to whiten principal components; used to initialize the latent pose trajectory `x`", - "conf_threshold": "used to define outliers for interpolation when the model is initialized", - "conf_pseudocount": "pseudocount used regularize neural network confidences", - "fix_heading": "whether to keep the heading angle fixed; this should only be True if the pose is constrained to a narrow range of angles, e.g. 
a headfixed mouse.", - } - - sections = [ - ("ANATOMY", anatomy), - ("FITTING", fitting), - ("HYPER PARAMS", hypperams), - ("OTHER", other), - ] - - with open(os.path.join(output_dir, "kpms_dj_config.yml"), "w") as f: - f.write(_build_yaml(sections, comments)) - - -def load_kpms_dj_config(output_dir, check_if_valid=True, build_indexes=True): - """ - This function mirrors the functionality of the `load_config` function from the `keypoint_moseq` - package. Similarly, this function loads the `kpms_dj_config.yml` from the output directory. +def _dj_config_path(project_dir: Union[str, os.PathLike]) -> str: + return str(Path(project_dir) / DJ_CONFIG) + - Args: - output_dir (str): Directory containing the `kpms_dj_config.yml` that will be loaded. - check_if_valid (bool): default=True. Check if the config is valid using :py:func:`keypoint_moseq.io.check_config_validity` - build_indexes (bool): default=True. Add keys `"anterior_idxs"` and `"posterior_idxs"` to the config. Each maps to a jax array indexing the elements of `config["anterior_bodyparts"]` and `config["posterior_bodyparts"]` by their order in `config["use_bodyparts"]` +def _base_config_path(project_dir: Union[str, os.PathLike]) -> str: + """Return the path to the base config file, checking for both .yml and .yaml extensions.""" + project_path = Path(project_dir) + # Check for config.yml first (preferred) + config_yml = project_path / "config.yml" + if config_yml.exists(): + return str(config_yml) + # Fall back to config.yaml + config_yaml = project_path / "config.yaml" + if config_yaml.exists(): + return str(config_yaml) + # If neither exists, return the default (config.yml) + return str(config_yml) - Returns: - kpms_dj_config (dict): configuration settings + +def _check_config_validity(config: Dict[str, Any]) -> bool: + """ + Minimal mirror of keypoint_moseq.io.check_config_validity logic that matters + for anatomy consistency (anterior/posterior must be subset of use_bodyparts). 
""" + errors = [] + for bp in config.get("anterior_bodyparts", []): + if bp not in config.get("use_bodyparts", []): + errors.append( + f"ACTION REQUIRED: `anterior_bodyparts` contains {bp} " + "which is not one of the options in `use_bodyparts`." + ) + for bp in config.get("posterior_bodyparts", []): + if bp not in config.get("use_bodyparts", []): + errors.append( + f"ACTION REQUIRED: `posterior_bodyparts` contains {bp} " + "which is not one of the options in `use_bodyparts`." + ) + if errors: + for e in errors: + print(e) + return False + return True + + +def dj_generate_config(project_dir: str, **kwargs) -> str: + """ + Generate or refresh `/kpms_dj_config.yml`. - from keypoint_moseq import check_config_validity + Behavior: + - If the DJ config doesn't exist, start from the **base** `/config.yml` + created by upstream `setup_project`, then overlay kwargs and write DJ config. + - If the DJ config exists, load it, overlay kwargs, and rewrite it. - config_path = os.path.join(output_dir, "kpms_dj_config.yml") + Returns the path to `kpms_dj_config.yml`. + """ + project_dir = str(project_dir) + base_cfg_path = _base_config_path(project_dir) + dj_cfg_path = _dj_config_path(project_dir) + + if os.path.exists(dj_cfg_path): + with open(dj_cfg_path, "r") as f: + cfg = yaml.safe_load(f) or {} + else: + if not os.path.exists(base_cfg_path): + raise FileNotFoundError( + f"Missing base config at {base_cfg_path}. Run upstream setup_project first. " + f"Expected either config.yml or config.yaml in {project_dir}." + ) + with open(base_cfg_path, "r") as f: + cfg = yaml.safe_load(f) or {} + cfg.update(kwargs) + + if "skeleton" not in cfg or cfg["skeleton"] is None: + cfg["skeleton"] = [] + + with open(dj_cfg_path, "w") as f: + yaml.safe_dump(cfg, f, sort_keys=False) + return dj_cfg_path + + +def load_kpms_dj_config( + project_dir: str, check_if_valid: bool = True, build_indexes: bool = True +) -> Dict[str, Any]: + """ + Load `/kpms_dj_config.yml`. 
- with open(config_path, "r") as f: - kpms_dj_config = yaml.safe_load(f) + Mirrors keypoint_moseq.io.load_config behavior: + - check_if_valid -> anatomy subset checks + - build_indexes -> adds jax arrays 'anterior_idxs' and 'posterior_idxs' + indexing into 'use_bodyparts' by order. + """ + dj_cfg_path = _dj_config_path(project_dir) + if not os.path.exists(dj_cfg_path): + raise FileNotFoundError( + f"Missing DJ config at {dj_cfg_path}. Create it with dj_generate_config()." + ) + + with open(dj_cfg_path, "r") as f: + cfg = yaml.safe_load(f) or {} if check_if_valid: - check_config_validity(kpms_dj_config) + _check_config_validity(cfg) if build_indexes: - kpms_dj_config["anterior_idxs"] = jnp.array( - [ - kpms_dj_config["use_bodyparts"].index(bp) - for bp in kpms_dj_config["anterior_bodyparts"] - ] - ) - kpms_dj_config["posterior_idxs"] = jnp.array( - [ - kpms_dj_config["use_bodyparts"].index(bp) - for bp in kpms_dj_config["posterior_bodyparts"] - ] + anterior = cfg.get("anterior_bodyparts", []) + posterior = cfg.get("posterior_bodyparts", []) + use_bps = cfg.get("use_bodyparts", []) + cfg["anterior_idxs"] = jnp.array([use_bps.index(bp) for bp in anterior]) + cfg["posterior_idxs"] = jnp.array([use_bps.index(bp) for bp in posterior]) + + if "skeleton" not in cfg or cfg["skeleton"] is None: + cfg["skeleton"] = [] + + return cfg + + +def update_kpms_dj_config(project_dir: str, **kwargs) -> Dict[str, Any]: + """ + Update `kpms_dj_config.yml` with provided top-level kwargs (same pattern as + keypoint_moseq.io.update_config), then rewrite the file and return the dict. + """ + dj_cfg_path = _dj_config_path(project_dir) + if not os.path.exists(dj_cfg_path): + raise FileNotFoundError( + f"Missing DJ config at {dj_cfg_path}. Create it with dj_generate_config()." 
) - if "skeleton" not in kpms_dj_config or kpms_dj_config["skeleton"] is None: - kpms_dj_config["skeleton"] = [] + with open(dj_cfg_path, "r") as f: + cfg = yaml.safe_load(f) or {} + + cfg.update(kwargs) - return kpms_dj_config + with open(dj_cfg_path, "w") as f: + yaml.safe_dump(cfg, f, sort_keys=False) + return cfg diff --git a/element_moseq/version.py b/element_moseq/version.py index 931952fd..0ec000f6 100644 --- a/element_moseq/version.py +++ b/element_moseq/version.py @@ -2,4 +2,4 @@ Package metadata """ -__version__ = "0.3.2" +__version__ = "1.0.0" diff --git a/images/pipeline.svg b/images/pipeline.svg index 0e1b1929..e1096ef0 100644 --- a/images/pipeline.svg +++ b/images/pipeline.svg @@ -1,327 +1,425 @@ - - - - + + + + -moseq_train.KeypointSet - - -moseq_train.KeypointSet +moseq_train.LatentDimension + + +moseq_train.LatentDimension - - -moseq_train.KeypointSet.VideoFile - - -moseq_train.KeypointSet.VideoFile + + +moseq_report.PCAReport + + +moseq_report.PCAReport - + -moseq_train.KeypointSet->moseq_train.KeypointSet.VideoFile - +moseq_train.LatentDimension->moseq_report.PCAReport + - - -moseq_train.Bodyparts - - -moseq_train.Bodyparts + + +moseq_infer.Inference.GridMoviesSampledInstances + + +moseq_infer.Inference.GridMoviesSampledInstances - - -moseq_train.KeypointSet->moseq_train.Bodyparts - + + +moseq_infer.VideoRecording.File + + +moseq_infer.VideoRecording.File + - - -moseq_train.FullFitTask - - -moseq_train.FullFitTask + + + +Device + + +Device - - -moseq_train.FullFit - - -moseq_train.FullFit + + +moseq_infer.VideoRecording + + +moseq_infer.VideoRecording - - -moseq_train.FullFitTask->moseq_train.FullFit - + + +Device->moseq_infer.VideoRecording + - - -moseq_train.PreFit - - -moseq_train.PreFit + + +moseq_train.PreProcessing + + +moseq_train.PreProcessing + + +moseq_report.PreProcessingReport + + +moseq_report.PreProcessingReport + + + + + +moseq_train.PreProcessing->moseq_report.PreProcessingReport + + - + moseq_train.PCAFit - - -moseq_train.PCAFit 
+ + +moseq_train.PCAFit - + -moseq_train.PCAFit->moseq_train.FullFitTask - - - - -moseq_train.LatentDimension - - -moseq_train.LatentDimension +moseq_train.PreProcessing->moseq_train.PCAFit + + + + +moseq_train.PreProcessing.Video + + +moseq_train.PreProcessing.Video - + -moseq_train.PCAFit->moseq_train.LatentDimension - +moseq_train.PreProcessing->moseq_train.PreProcessing.Video + - + moseq_train.PreFitTask - - -moseq_train.PreFitTask + + +moseq_train.PreFitTask - - -moseq_train.PCAFit->moseq_train.PreFitTask - - - - -moseq_infer.Model - - -moseq_infer.Model + + +moseq_train.PreFit + + +moseq_train.PreFit - + + +moseq_train.PreFitTask->moseq_train.PreFit + + + -moseq_infer.InferenceTask - - -moseq_infer.InferenceTask +moseq_infer.Inference + + +moseq_infer.Inference - + -moseq_infer.Model->moseq_infer.InferenceTask - +moseq_infer.Inference->moseq_infer.Inference.GridMoviesSampledInstances + - - -moseq_infer.Inference.GridMoviesSampledInstances - - -moseq_infer.Inference.GridMoviesSampledInstances + + +moseq_infer.Inference.MotionSequence + + +moseq_infer.Inference.MotionSequence - - -moseq_infer.Inference - - -moseq_infer.Inference + + +moseq_infer.Inference->moseq_infer.Inference.MotionSequence + + + + +moseq_report.InferenceReport + + +moseq_report.InferenceReport - - -moseq_infer.InferenceTask->moseq_infer.Inference - + + +moseq_infer.Inference->moseq_report.InferenceReport + - + +moseq_report.FullFitReport + + +moseq_report.FullFitReport + + + + + moseq_train.SelectedFullFit - - -moseq_train.SelectedFullFit + + +moseq_train.SelectedFullFit + + + + + +moseq_infer.Model + + +moseq_infer.Model - + moseq_train.SelectedFullFit->moseq_infer.Model - + - + -moseq_train.PoseEstimationMethod - - -moseq_train.PoseEstimationMethod +subject.Subject + + +subject.Subject - - -moseq_train.PoseEstimationMethod->moseq_train.KeypointSet - + + +session.Session + + +session.Session + - + + -moseq_train.PoseEstimationMethod->moseq_infer.InferenceTask - 
+subject.Subject->session.Session + - + + +moseq_report.InferenceReport.Trajectory + + +moseq_report.InferenceReport.Trajectory + + + + + +moseq_infer.InferenceTask + + +moseq_infer.InferenceTask + + + + -moseq_train.PreFitTask->moseq_train.PreFit - +moseq_infer.Model->moseq_infer.InferenceTask + - + -Device - - -Device +moseq_train.KeypointSet.VideoFile + + +moseq_train.KeypointSet.VideoFile - - -moseq_infer.VideoRecording - - -moseq_infer.VideoRecording + + +moseq_train.PoseEstimationMethod + + +moseq_train.PoseEstimationMethod - + -Device->moseq_infer.VideoRecording - +moseq_train.PoseEstimationMethod->moseq_infer.InferenceTask + - - -moseq_infer.VideoRecording.File - - -moseq_infer.VideoRecording.File + + +moseq_train.KeypointSet + + +moseq_train.KeypointSet - - -moseq_train.PCATask - - -moseq_train.PCATask - + + +moseq_train.PoseEstimationMethod->moseq_train.KeypointSet + + + +moseq_train.FullFitTask + + +moseq_train.FullFitTask + - - -moseq_train.Bodyparts->moseq_train.PCATask - - - -moseq_train.PCAPrep - - -moseq_train.PCAPrep + + +moseq_train.FullFit + + +moseq_train.FullFit - + -moseq_train.PCAPrep->moseq_train.PCAFit - +moseq_train.FullFitTask->moseq_train.FullFit + - + -moseq_infer.Inference->moseq_infer.Inference.GridMoviesSampledInstances - +moseq_report.InferenceReport->moseq_report.InferenceReport.Trajectory + - - -moseq_infer.Inference.MotionSequence - - -moseq_infer.Inference.MotionSequence - + + +moseq_infer.InferenceTask->moseq_infer.Inference + + + +moseq_report.PreFitReport + + +moseq_report.PreFitReport + - - -moseq_infer.Inference->moseq_infer.Inference.MotionSequence - - + -moseq_infer.VideoRecording->moseq_infer.InferenceTask - +moseq_train.PreFit->moseq_report.PreFitReport + - + -moseq_infer.VideoRecording->moseq_infer.VideoRecording.File - +moseq_train.PCAFit->moseq_train.LatentDimension + - + -moseq_train.FullFit->moseq_train.SelectedFullFit - +moseq_train.PCAFit->moseq_train.PreFitTask + - - -subject.Subject - - -subject.Subject + + 
+moseq_train.PCAFit->moseq_train.FullFitTask + + + + +moseq_infer.VideoRecording->moseq_infer.VideoRecording.File + + + + +moseq_infer.VideoRecording->moseq_infer.InferenceTask + + + + +session.Session->moseq_infer.VideoRecording + + + + +moseq_train.KeypointSet->moseq_train.KeypointSet.VideoFile + + + + +moseq_train.Bodyparts + + +moseq_train.Bodyparts - - -session.Session - - -session.Session + + +moseq_train.KeypointSet->moseq_train.Bodyparts + + + + +moseq_train.PCATask + + +moseq_train.PCATask - - -subject.Subject->session.Session - + + +moseq_train.Bodyparts->moseq_train.PCATask + - - -moseq_train.PCATask->moseq_train.PCAPrep - + + +moseq_train.PCATask->moseq_train.PreProcessing + - - -session.Session->moseq_infer.VideoRecording - + + +moseq_train.FullFit->moseq_report.FullFitReport + + + + +moseq_train.FullFit->moseq_train.SelectedFullFit + \ No newline at end of file diff --git a/images/pipeline_moseq_infer.svg b/images/pipeline_moseq_infer.svg index cd89e7a3..8c9050f9 100644 --- a/images/pipeline_moseq_infer.svg +++ b/images/pipeline_moseq_infer.svg @@ -1,98 +1,98 @@ - - - - + + + + -moseq_infer.VideoRecording.File - - -moseq_infer.VideoRecording.File +moseq_infer.InferenceTask + + +moseq_infer.InferenceTask - - -moseq_infer.Model - - -moseq_infer.Model + + +moseq_infer.Inference + + +moseq_infer.Inference - - -moseq_infer.InferenceTask - - -moseq_infer.InferenceTask - + + +moseq_infer.InferenceTask->moseq_infer.Inference + + + +moseq_infer.Inference.GridMoviesSampledInstances + + +moseq_infer.Inference.GridMoviesSampledInstances + - - -moseq_infer.Model->moseq_infer.InferenceTask - - + -moseq_infer.Inference - - -moseq_infer.Inference +moseq_infer.VideoRecording.File + + +moseq_infer.VideoRecording.File - - -moseq_infer.Inference.GridMoviesSampledInstances - - -moseq_infer.Inference.GridMoviesSampledInstances + + +moseq_infer.Model + + +moseq_infer.Model - + +moseq_infer.Model->moseq_infer.InferenceTask + + + + 
moseq_infer.Inference->moseq_infer.Inference.GridMoviesSampledInstances - + - + moseq_infer.Inference.MotionSequence - - -moseq_infer.Inference.MotionSequence + + +moseq_infer.Inference.MotionSequence - + moseq_infer.Inference->moseq_infer.Inference.MotionSequence - + - + moseq_infer.VideoRecording - - -moseq_infer.VideoRecording + + +moseq_infer.VideoRecording - - -moseq_infer.VideoRecording->moseq_infer.VideoRecording.File - - moseq_infer.VideoRecording->moseq_infer.InferenceTask - + - + -moseq_infer.InferenceTask->moseq_infer.Inference - +moseq_infer.VideoRecording->moseq_infer.VideoRecording.File + \ No newline at end of file diff --git a/images/pipeline_moseq_train.svg b/images/pipeline_moseq_train.svg index 7be6103d..eeb20963 100644 --- a/images/pipeline_moseq_train.svg +++ b/images/pipeline_moseq_train.svg @@ -1,182 +1,196 @@ - - - - + + + + -moseq_train.KeypointSet.VideoFile - - -moseq_train.KeypointSet.VideoFile +moseq_train.LatentDimension + + +moseq_train.LatentDimension - + -moseq_train.PoseEstimationMethod - - -moseq_train.PoseEstimationMethod +moseq_train.PreFit + + +moseq_train.PreFit - - -moseq_train.KeypointSet - - -moseq_train.KeypointSet + + +moseq_train.PreProcessing + + +moseq_train.PreProcessing - + + +moseq_train.PCAFit + + +moseq_train.PCAFit + + + + -moseq_train.PoseEstimationMethod->moseq_train.KeypointSet - +moseq_train.PreProcessing->moseq_train.PCAFit + - - -moseq_train.LatentDimension - - -moseq_train.LatentDimension + + +moseq_train.PreProcessing.Video + + +moseq_train.PreProcessing.Video - + + +moseq_train.PreProcessing->moseq_train.PreProcessing.Video + + + -moseq_train.Bodyparts - - -moseq_train.Bodyparts +moseq_train.PreFitTask + + +moseq_train.PreFitTask - - -moseq_train.PCATask - - -moseq_train.PCATask - + + +moseq_train.PreFitTask->moseq_train.PreFit + + + +moseq_train.PCAFit->moseq_train.LatentDimension + - - -moseq_train.Bodyparts->moseq_train.PCATask - + + +moseq_train.PCAFit->moseq_train.PreFitTask + - - 
-moseq_train.PCAPrep - - -moseq_train.PCAPrep + + +moseq_train.FullFitTask + + +moseq_train.FullFitTask - - -moseq_train.PCAFit - - -moseq_train.PCAFit + + +moseq_train.PCAFit->moseq_train.FullFitTask + + + + +moseq_train.Bodyparts + + +moseq_train.Bodyparts - - -moseq_train.PCAPrep->moseq_train.PCAFit - + + +moseq_train.PCATask + + +moseq_train.PCATask + - - -moseq_train.KeypointSet->moseq_train.KeypointSet.VideoFile - - - -moseq_train.KeypointSet->moseq_train.Bodyparts - + + +moseq_train.Bodyparts->moseq_train.PCATask + - - -moseq_train.FullFitTask - - -moseq_train.FullFitTask + + +moseq_train.KeypointSet.VideoFile + + +moseq_train.KeypointSet.VideoFile + + +moseq_train.PCATask->moseq_train.PreProcessing + + - + moseq_train.FullFit - - -moseq_train.FullFit + + +moseq_train.FullFit - - -moseq_train.FullFitTask->moseq_train.FullFit - - moseq_train.SelectedFullFit - -moseq_train.SelectedFullFit + +moseq_train.SelectedFullFit - + moseq_train.FullFit->moseq_train.SelectedFullFit - + - - -moseq_train.PreFitTask - - -moseq_train.PreFitTask - - - - + -moseq_train.PreFit - - -moseq_train.PreFit +moseq_train.PoseEstimationMethod + + +moseq_train.PoseEstimationMethod - - -moseq_train.PreFitTask->moseq_train.PreFit - + + +moseq_train.KeypointSet + + +moseq_train.KeypointSet + - - -moseq_train.PCATask->moseq_train.PCAPrep - - + -moseq_train.PCAFit->moseq_train.LatentDimension - +moseq_train.PoseEstimationMethod->moseq_train.KeypointSet + - + -moseq_train.PCAFit->moseq_train.FullFitTask - +moseq_train.KeypointSet->moseq_train.Bodyparts + - + -moseq_train.PCAFit->moseq_train.PreFitTask - +moseq_train.KeypointSet->moseq_train.KeypointSet.VideoFile + + + + +moseq_train.FullFitTask->moseq_train.FullFit + \ No newline at end of file diff --git a/notebooks/tutorial_pipeline.py b/notebooks/tutorial_pipeline.py index 60c7171f..0f86f32e 100644 --- a/notebooks/tutorial_pipeline.py +++ b/notebooks/tutorial_pipeline.py @@ -4,7 +4,7 @@ from element_lab import lab from element_animal import 
subject from element_session import session_with_datetime as session -from element_moseq import moseq_train, moseq_infer +from element_moseq import moseq_train, moseq_infer, moseq_report from element_animal.subject import Subject from element_lab.lab import Source, Lab, Protocol, User, Project @@ -38,7 +38,15 @@ def get_kpms_processed_data_dir() -> str: return None -__all__ = ["lab", "subject", "session", "moseq_train", "moseq_infer", "Device"] +__all__ = [ + "lab", + "subject", + "session", + "moseq_train", + "moseq_infer", + "moseq_report", + "Device", +] # Activate schemas ------------- @@ -80,3 +88,4 @@ class Device(dj.Lookup): moseq_train.activate(db_prefix + "moseq_train", linking_module=__name__) moseq_infer.activate(db_prefix + "moseq_infer", linking_module=__name__) +moseq_report.activate(db_prefix + "moseq_report", linking_module=__name__) diff --git a/pyproject.toml b/pyproject.toml index 115b9f1a..f4a38cf9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "element-moseq" -version = "0.3.1" +version = "1.0.0" description = "Keypoint-MoSeq DataJoint Element" readme = "README.md" license = {text = "MIT"} @@ -24,12 +24,7 @@ dependencies = [ "ipykernel>=6.0.1", "ipywidgets", "opencv-python", - "keypoint-moseq==0.4.8", - "dynamax==0.1.4", - "jax==0.4.13", - "jaxlib==0.4.13", - "jaxtyping==0.2.14", - "jupyter_bokeh", + "keypoint-moseq @ git+https://github.com/dattalab/keypoint-moseq/", ] [project.optional-dependencies]