From 778d76329b3e1510f5051b24d0a83c063c1369ed Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 26 Aug 2025 19:44:55 +0100 Subject: [PATCH 01/61] feat: several fixes to config yml overwrite and add feat : add removal of outlier keypoints, fix generation and udpate of config yml file, rename table from PCAPre to PreProcessing, add new attribute, add 3-part make function --- element_moseq/moseq_infer.py | 6 +- element_moseq/moseq_train.py | 270 ++++++++++++++++++------- element_moseq/readers/kpms_reader.py | 285 ++++++++++++--------------- 3 files changed, 318 insertions(+), 243 deletions(-) diff --git a/element_moseq/moseq_infer.py b/element_moseq/moseq_infer.py index 9e38515f..6e5e7567 100644 --- a/element_moseq/moseq_infer.py +++ b/element_moseq/moseq_infer.py @@ -14,7 +14,7 @@ from matplotlib import pyplot as plt from . import moseq_train -from .readers.kpms_reader import load_kpms_dj_config +from .readers import kpms_reader schema = dj.schema() _linking_module = None @@ -300,9 +300,7 @@ def make(self, key): support for another format method, please reach out to us at `support@datajoint.com`." ) - kpms_dj_config = load_kpms_dj_config( - model_dir.parent.as_posix(), check_if_valid=True, build_indexes=True - ) + kpms_dj_config = kpms_reader.dj_load_config(model_dir.parent) if kpms_dj_config: data, metadata = format_data(coordinates, confidences, **kpms_dj_config) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index f905dbab..353d8a44 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -35,13 +35,12 @@ def activate( Args: train_schema_name (str): A string containing the name of the `moseq_train` schema. - infer_schema_name (str): A string containing the name of the `moseq_infer` schema. create_schema (bool): If True (default), schema will be created in the database. create_tables (bool): If True (default), tables related to the schema will be created in the database. 
linking_module (str): A string containing the module name or module containing the required dependencies to activate the schema. Dependencies: Functions: - get_kpms_root_data_dir(): Returns absolute path for root data director(y/ies) + get_kpms_root_data_dir(): Returns absolute path for root data directory/ies. with all behavioral recordings, as (list of) string(s). get_kpms_processed_data_dir(): Optional. Returns absolute path for processed data. @@ -52,7 +51,7 @@ def activate( linking_module = importlib.import_module(linking_module) assert inspect.ismodule( linking_module - ), "The argument 'dependency' must be a module's name or a module" + ), "The argument 'dependency' must be a module's name or a module object" assert hasattr( linking_module, "get_kpms_root_data_dir" @@ -77,7 +76,7 @@ def get_kpms_root_data_dir() -> list: """Pulls relevant func from parent namespace to specify root data dir(s). It is recommended that all paths in DataJoint Elements stored as relative - paths, with respect to some user-configured "root" director(y/ies). The + paths, with respect to some user-configured "root" directory/ies. The root(s) may vary between data modalities and user machines. Returns a full path string or list of strings for possible root data directories. """ @@ -209,6 +208,7 @@ class PCATask(dj.Manual): definition = """ -> Bodyparts # Unique ID for each `Bodyparts` key --- + outlier_scale_factor=6 : int # Scale factor for outlier detection (default: 6) kpms_project_output_dir='' : varchar(255) # Keypoint-MoSeq project output directory, relative to root data directory task_mode='load' :enum('load','trigger') # Trigger or load the task @@ -216,7 +216,7 @@ class PCATask(dj.Manual): @schema -class PCAPrep(dj.Imported): +class PreProcessing(dj.Imported): """ Table to set up the Keypoint-MoSeq project output directory (`kpms_project_output_dir`) , creating the default `config.yml` and updating it in a new `kpms_dj_config.yml`. 
@@ -239,7 +239,7 @@ class PCAPrep(dj.Imported): frame_rates : longblob # List of the frame rates of the videos for model training """ - def make(self, key): + def make_fetch(self, key): """ Make function to: 1. Generate and update the `kpms_dj_config.yml` with both the videoset directory and the bodyparts. @@ -260,7 +260,7 @@ def make(self, key): 6. Create a copy of the kpms `config.yml` named `kpms_dj_config.yml` that will be updated with both the `video_dir` and bodyparts 7. Load keypoint data from the keypoint files found in the `kpset_dir` that will serve as the training set. 8. As a result of the keypoint loading, the coordinates and confidences scores are generated and will be used to format the data for modeling. - 9. Calculate the average frame rate and the frame rate list of the videoset from which the keypoint set is derived. This two attributes can be used to calculate the kappa value. + 9. Calculate the average frame rate and the frame rate list of the videoset from which the keypoint set is derived. These two attributes can be used to calculate the kappa value. 10. Insert the results of this `make` function into the table. """ @@ -283,8 +283,64 @@ def make(self, key): "kpms_project_output_dir", "task_mode" ) + outlier_scale_factor = (PCATask & key).fetch1("outlier_scale_factor") + + return ( + anterior_bodyparts, + posterior_bodyparts, + use_bodyparts, + pose_estimation_method, + kpset_dir, + video_paths, + video_ids, + kpms_project_output_dir, + task_mode, + outlier_scale_factor, + ) + + def make_compute( + self, + key, + anterior_bodyparts, + posterior_bodyparts, + use_bodyparts, + pose_estimation_method, + kpset_dir, + video_paths, + video_ids, + kpms_project_output_dir, + task_mode, + outlier_scale_factor, + ): + """Compute the outlier keypoints and interpolate them. + + Args: + key (dict): Primary key from the `PCATask` table. + anterior_bodyparts (list): List of anterior bodyparts. + posterior_bodyparts (list): List of posterior bodyparts. 
+ + Raises: + NotImplementedError: `pose_estimation_method` is only supported for `deeplabcut`. + FileNotFoundError: No DLC config.(yml|yaml) found in `kpset_dir`. + Exception: Could not create outlier plot for `recording_name`. + + Logic: + 1. Load the project configuration for outlier detection + 2. Find outliers using medoid distance analysis + 3. Interpolate keypoints to fix outliers + 4. Update confidences for outlier points + 5. Store results + """ + + from keypoint_moseq import ( + find_medoid_distance_outliers, + interpolate_keypoints, + load_keypoints, + plot_medoid_distance_outliers, + ) + if task_mode == "trigger": - from keypoint_moseq import load_config, setup_project + from keypoint_moseq import setup_project try: kpms_project_output_dir = find_full_path( @@ -293,7 +349,7 @@ def make(self, key): except FileNotFoundError: kpms_project_output_dir = ( - get_kpms_processed_data_dir() / kpms_project_output_dir + Path(get_kpms_processed_data_dir()) / kpms_project_output_dir ) kpset_dir = find_full_path(get_kpms_root_data_dir(), kpset_dir) @@ -302,35 +358,25 @@ def make(self, key): ) if pose_estimation_method == "deeplabcut": + cfg = kpset_dir / "config.yaml" + if not cfg.exists(): + cfg = kpset_dir / "config.yml" + if not cfg.exists(): + raise FileNotFoundError( + f"No DLC config.(yml|yaml) found in {kpset_dir}" + ) + # base `config.yml` is created with task_mode='trigger' setup_project( project_dir=kpms_project_output_dir.as_posix(), - deeplabcut_config=(kpset_dir / "config.yaml") - or (kpset_dir / "config.yml"), + deeplabcut_config=cfg.as_posix(), ) + else: raise NotImplementedError( "Currently, `deeplabcut` is the only pose estimation method supported by this Element. Please reach out at `support@datajoint.com` if you use another method." 
) - kpms_config = load_config( - kpms_project_output_dir.as_posix(), - check_if_valid=True, - build_indexes=False, - ) - - kpms_dj_config_kwargs_dict = dict( - video_dir=videos_dir.as_posix(), - anterior_bodyparts=anterior_bodyparts, - posterior_bodyparts=posterior_bodyparts, - use_bodyparts=use_bodyparts, - ) - kpms_config.update(**kpms_dj_config_kwargs_dict) - kpms_reader.generate_kpms_dj_config( - kpms_project_output_dir.as_posix(), **kpms_config - ) else: - from keypoint_moseq import load_keypoints - kpms_project_output_dir = find_full_path( get_kpms_processed_data_dir(), kpms_project_output_dir ) @@ -339,23 +385,96 @@ def make(self, key): get_kpms_root_data_dir(), Path(video_paths[0]).parent ) - coordinates, confidences, formatted_bodyparts = load_keypoints( + raw_coordinates, raw_confidences, formatted_bodyparts = load_keypoints( filepath_pattern=kpset_dir, format=pose_estimation_method ) + # compute FPS frame_rate_list = [] for fp, _ in zip(video_paths, video_ids): video_path = (find_full_path(get_kpms_root_data_dir(), fp)).as_posix() cap = cv2.VideoCapture(video_path) - frame_rate_list.append(int(cap.get(cv2.CAP_PROP_FPS))) + fps = float(cap.get(cv2.CAP_PROP_FPS)) cap.release() - average_frame_rate = int(np.mean(frame_rate_list)) + frame_rate_list.append(fps) + average_frame_rate = float(np.mean(frame_rate_list)) + + # Generate a copy of config.yml with the generated/updated info after it is known + kpms_reader.dj_generate_config( + project_dir=kpms_project_output_dir, + video_dir=str(videos_dir), + use_bodyparts=list(use_bodyparts), + anterior_bodyparts=list(anterior_bodyparts), + posterior_bodyparts=list(posterior_bodyparts), + outlier_scale_factor=float(outlier_scale_factor), + ) + kpms_reader.dj_update_config( + kpms_project_output_dir, + fps=average_frame_rate, + ) + + # Remove outlier keypoints + kpms_config = kpms_reader.dj_load_config(kpms_project_output_dir) + cleaned_coordinates = {} + cleaned_confidences = {} + + for recording_name in 
raw_coordinates: + raw_coords = raw_coordinates[recording_name].copy() + raw_conf = raw_confidences[recording_name].copy() + + # Find outliers using medoid distance analysis + outliers = find_medoid_distance_outliers( + raw_coords, outlier_scale_factor=outlier_scale_factor + ) + + # Interpolate keypoints to fix outliers + cleaned_coords = interpolate_keypoints(raw_coords, outliers["mask"]) + + # Update confidences for outlier points + cleaned_conf = np.where(outliers["mask"], 0, raw_conf) + + # Store results + cleaned_coordinates[recording_name] = cleaned_coords + cleaned_confidences[recording_name] = cleaned_conf + + # Plot outliers + if formatted_bodyparts is not None: + try: + plot_medoid_distance_outliers( + project_dir=kpms_project_output_dir.as_posix(), + recording_name=recording_name, + original_coordinates=raw_coords, + interpolated_coordinates=cleaned_coords, + outlier_mask=outliers["mask"], + outlier_thresholds=outliers["thresholds"], + **kpms_config, + ) + + except Exception as e: + print(f"Could not create outlier plot for {recording_name}: {e}") + + return ( + cleaned_coordinates, + cleaned_confidences, + formatted_bodyparts, + average_frame_rate, + frame_rate_list, + ) + def make_insert( + self, + key, + cleaned_coordinates, + cleaned_confidences, + formatted_bodyparts, + average_frame_rate, + frame_rate_list, + ): self.insert1( dict( **key, - coordinates=coordinates, - confidences=confidences, + coordinates=cleaned_coordinates, + confidences=cleaned_confidences, formatted_bodyparts=formatted_bodyparts, average_frame_rate=average_frame_rate, frame_rates=frame_rate_list, @@ -368,12 +487,12 @@ class PCAFit(dj.Computed): """Fit PCA model. Attributes: - PCAPrep (foreign key) : `PCAPrep` Key. + PreProcessing (foreign key) : `PreProcessing` Key. pca_fit_time (datetime) : datetime of the PCA fitting analysis. 
""" definition = """ - -> PCAPrep # `PCAPrep` Key + -> PreProcessing # `PreProcessing` Key --- pca_fit_time=NULL : datetime # datetime of the PCA fitting analysis """ @@ -383,7 +502,7 @@ def make(self, key): Make function to format the keypoint data, fit the PCA model, and store it as a `pca.p` file in the Keypoint-MoSeq project output directory. Args: - key (dict): `PCAPrep` Key + key (dict): `PreProcessing` Key Raises: @@ -400,13 +519,13 @@ def make(self, key): "kpms_project_output_dir", "task_mode" ) kpms_project_output_dir = ( - get_kpms_processed_data_dir() / kpms_project_output_dir + Path(get_kpms_processed_data_dir()) / kpms_project_output_dir ) - kpms_default_config = kpms_reader.load_kpms_dj_config( - kpms_project_output_dir.as_posix(), check_if_valid=True, build_indexes=True + kpms_default_config = kpms_reader.dj_load_config(kpms_project_output_dir) + coordinates, confidences = (PreProcessing & key).fetch1( + "coordinates", "confidences" ) - coordinates, confidences = (PCAPrep & key).fetch1("coordinates", "confidences") data, _ = format_data( **kpms_default_config, coordinates=coordinates, confidences=confidences ) @@ -464,37 +583,38 @@ def make(self, key): it sets the `latent_dimension` to the number of components that explain the specified variance and `variance_percentage` to the specified variance threshold. 4. Insert the results of this `make` function into the table. 
""" + + VARIANCE_THRESHOLD = 0.90 + from keypoint_moseq import load_pca kpms_project_output_dir = (PCATask & key).fetch1("kpms_project_output_dir") kpms_project_output_dir = ( - get_kpms_processed_data_dir() / kpms_project_output_dir + Path(get_kpms_processed_data_dir()) / kpms_project_output_dir ) pca_path = kpms_project_output_dir / "pca.p" - if pca_path: + if pca_path.exists(): pca = load_pca(kpms_project_output_dir.as_posix()) else: raise FileNotFoundError( f"No pca model (`pca.p`) found in the project directory {kpms_project_output_dir}" ) - variance_threshold = 0.90 - cs = np.cumsum( pca.explained_variance_ratio_ ) # explained_variance_ratio_ndarray of shape (n_components,) - if cs[-1] < variance_threshold: + if cs[-1] < VARIANCE_THRESHOLD: latent_dimension = len(cs) variance_percentage = cs[-1] * 100 latent_dim_desc = ( f"All components together only explain {cs[-1]*100}% of variance." ) else: - latent_dimension = (cs > variance_threshold).nonzero()[0].min() + 1 - variance_percentage = variance_threshold * 100 - latent_dim_desc = f">={variance_threshold*100}% of variance explained by {(cs>variance_threshold).nonzero()[0].min()+1} components." + latent_dimension = (cs > VARIANCE_THRESHOLD).nonzero()[0].min() + 1 + variance_percentage = VARIANCE_THRESHOLD * 100 + latent_dim_desc = f">={VARIANCE_THRESHOLD*100}% of variance explained by {(cs>VARIANCE_THRESHOLD).nonzero()[0].min()+1} components." 
self.insert1( dict( @@ -525,7 +645,7 @@ class PreFitTask(dj.Manual): pre_num_iterations : int # Number of Gibbs sampling iterations to run in the model pre-fitting --- model_name : varchar(100) # Name of the model to be loaded if `task_mode='load'` - task_mode='load' :enum('trigger','load')# 'load': load computed analysis results, 'trigger': trigger computation + task_mode='load' :enum('load','trigger')# 'load': load computed analysis results, 'trigger': trigger computation pre_fit_desc='' : varchar(1000) # User-defined description of the pre-fitting task """ @@ -592,31 +712,31 @@ def make(self, key): "model_name", ) if task_mode == "trigger": - kpms_dj_config = kpms_reader.load_kpms_dj_config( - kpms_project_output_dir.as_posix(), - check_if_valid=True, - build_indexes=True, - ) - kpms_dj_config.update( - dict(latent_dim=int(pre_latent_dim), kappa=float(pre_kappa)) - ) - kpms_reader.generate_kpms_dj_config( - kpms_project_output_dir.as_posix(), **kpms_dj_config + # Update the existing kpms_dj_config.yml with new latent_dim and kappa values + kpms_reader.dj_update_config( + kpms_project_output_dir, + latent_dim=int(pre_latent_dim), + kappa=float(pre_kappa), ) + # Load the updated config for use in model fitting + kpms_dj_config = kpms_reader.dj_load_config(kpms_project_output_dir) + pca_path = kpms_project_output_dir / "pca.p" - if pca_path: + if pca_path.exists(): pca = load_pca(kpms_project_output_dir.as_posix()) else: raise FileNotFoundError( f"No pca model (`pca.p`) found in the project directory {kpms_project_output_dir}" ) - coordinates, confidences = (PCAPrep & key).fetch1( + coordinates, confidences = (PreProcessing & key).fetch1( "coordinates", "confidences" ) - data, metadata = format_data(coordinates, confidences, **kpms_dj_config) + data, metadata = format_data( + coordinates=coordinates, confidences=confidences, **kpms_dj_config + ) model = init_model(data=data, metadata=metadata, pca=pca, **kpms_dj_config) @@ -742,30 +862,28 @@ def make(self, key): 
"model_name", ) if task_mode == "trigger": - kpms_dj_config = kpms_reader.load_kpms_dj_config( - kpms_project_output_dir.as_posix(), - check_if_valid=True, - build_indexes=True, - ) - kpms_dj_config.update( - dict(latent_dim=int(full_latent_dim), kappa=float(full_kappa)) - ) - kpms_reader.generate_kpms_dj_config( - kpms_project_output_dir.as_posix(), **kpms_dj_config + kpms_reader.dj_update_config( + kpms_project_output_dir, + latent_dim=int(full_latent_dim), + kappa=float(full_kappa), ) + kpms_dj_config = kpms_reader.dj_load_config(kpms_project_output_dir) + pca_path = kpms_project_output_dir / "pca.p" - if pca_path: + if pca_path.exists(): pca = load_pca(kpms_project_output_dir.as_posix()) else: raise FileNotFoundError( f"No pca model (`pca.p`) found in the project directory {kpms_project_output_dir}" ) - coordinates, confidences = (PCAPrep & key).fetch1( + coordinates, confidences = (PreProcessing & key).fetch1( "coordinates", "confidences" ) - data, metadata = format_data(coordinates, confidences, **kpms_dj_config) + data, metadata = format_data( + coordinates=coordinates, confidences=confidences, **kpms_dj_config + ) model = init_model(data=data, metadata=metadata, pca=pca, **kpms_dj_config) model = update_hypparams( model, kappa=float(full_kappa), latent_dim=int(full_latent_dim) diff --git a/element_moseq/readers/kpms_reader.py b/element_moseq/readers/kpms_reader.py index 4664fdcf..7f3a5367 100644 --- a/element_moseq/readers/kpms_reader.py +++ b/element_moseq/readers/kpms_reader.py @@ -1,5 +1,7 @@ import logging import os +from pathlib import Path +from typing import Any, Dict, Union import jax.numpy as jnp import yaml @@ -7,181 +9,138 @@ logger = logging.getLogger("datajoint") -def generate_kpms_dj_config(output_dir, **kwargs): - """This function mirrors the behavior of the `generate_config` function from the `keypoint_moseq` - package. 
Nonetheless, it produces a duplicate of the initial configuration file, titled - `kpms_dj_config.yml`, in the output directory to maintain the integrity of the original file. - This replicated file accommodates any customized project settings, with default configurations - utilized unless specified differently via keyword arguments. +DJ_CONFIG = "kpms_dj_config.yml" +BASE_CONFIG = "config.yml" - Args: - output_dir (str): Directory containing the `kpms_dj_config.yml` that will be generated. - kwargs (dict): Custom project settings. - """ - def _build_yaml(sections, comments): - text_blocks = [] - for title, data in sections: - centered_title = f" {title} ".center(50, "=") - text_blocks.append(f"\n\n{'#'}{centered_title}{'#'}") - for key, value in data.items(): - text = yaml.dump({key: value}).strip("\n") - if key in comments: - text = f"\n{'#'} {comments[key]}\n{text}" - text_blocks.append(text) - return "\n".join(text_blocks) - - def _update_dict(new, original): - return {k: new[k] if k in new else v for k, v in original.items()} - - hypperams = _update_dict( - kwargs, - { - "error_estimator": {"slope": -0.5, "intercept": 0.25}, - "obs_hypparams": { - "sigmasq_0": 0.1, - "sigmasq_C": 0.1, - "nu_sigma": 1e5, - "nu_s": 5, - }, - "ar_hypparams": { - "latent_dim": 10, - "nlags": 3, - "S_0_scale": 0.01, - "K_0_scale": 10.0, - }, - "trans_hypparams": { - "num_states": 100, - "gamma": 1e3, - "alpha": 5.7, - "kappa": 1e6, - }, - "cen_hypparams": {"sigmasq_loc": 0.5}, - }, - ) - - hypperams = {k: _update_dict(kwargs, v) for k, v in hypperams.items()} - - anatomy = _update_dict( - kwargs, - { - "bodyparts": ["BODYPART1", "BODYPART2", "BODYPART3"], - "use_bodyparts": ["BODYPART1", "BODYPART2", "BODYPART3"], - "skeleton": [ - ["BODYPART1", "BODYPART2"], - ["BODYPART2", "BODYPART3"], - ], - "anterior_bodyparts": ["BODYPART1"], - "posterior_bodyparts": ["BODYPART3"], - }, - ) - - other = _update_dict( - kwargs, - { - "recording_name_suffix": "", - "verbose": False, - 
"conf_pseudocount": 1e-3, - "video_dir": "", - "keypoint_colormap": "autumn", - "whiten": True, - "fix_heading": False, - "seg_length": 10000, - }, - ) - - fitting = _update_dict( - kwargs, - { - "added_noise_level": 0.1, - "PCA_fitting_num_frames": 1000000, - "conf_threshold": 0.5, - # 'kappa_scan_target_duration': 12, - # 'kappa_scan_min': 1e2, - # 'kappa_scan_max': 1e12, - # 'num_arhmm_scan_iters': 50, - # 'num_arhmm_final_iters': 200, - # 'num_kpslds_scan_iters': 50, - # 'num_kpslds_final_iters': 500 - }, - ) - - comments = { - "verbose": "whether to print progress messages during fitting", - "keypoint_colormap": "colormap used for visualization; see `matplotlib.cm.get_cmap` for options", - "added_noise_level": "upper bound of uniform noise added to the data during initial AR-HMM fitting; this is used to regularize the model", - "PCA_fitting_num_frames": "number of frames used to fit the PCA model during initialization", - "video_dir": "directory with videos from which keypoints were derived (used for crowd movies)", - "recording_name_suffix": "suffix used to match videos to recording names; this can usually be left empty (see `util.find_matching_videos` for details)", - "bodyparts": "used to access columns in the keypoint data", - "skeleton": "used for visualization only", - "use_bodyparts": "determines the subset of bodyparts to use for modeling and the order in which they are represented", - "anterior_bodyparts": "used to initialize heading", - "posterior_bodyparts": "used to initialize heading", - "seg_length": "data are broken up into segments to parallelize fitting", - "trans_hypparams": "transition hyperparameters", - "ar_hypparams": "autoregressive hyperparameters", - "obs_hypparams": "keypoint observation hyperparameters", - "cen_hypparams": "centroid movement hyperparameters", - "error_estimator": "parameters to convert neural net likelihoods to error size priors", - "save_every_n_iters": "frequency for saving model snapshots during fitting; if 0 only 
final state is saved", - "kappa_scan_target_duration": "target median syllable duration (in frames) for choosing kappa", - "whiten": "whether to whiten principal components; used to initialize the latent pose trajectory `x`", - "conf_threshold": "used to define outliers for interpolation when the model is initialized", - "conf_pseudocount": "pseudocount used regularize neural network confidences", - "fix_heading": "whether to keep the heading angle fixed; this should only be True if the pose is constrained to a narrow range of angles, e.g. a headfixed mouse.", - } - - sections = [ - ("ANATOMY", anatomy), - ("FITTING", fitting), - ("HYPER PARAMS", hypperams), - ("OTHER", other), - ] - - with open(os.path.join(output_dir, "kpms_dj_config.yml"), "w") as f: - f.write(_build_yaml(sections, comments)) - - -def load_kpms_dj_config(output_dir, check_if_valid=True, build_indexes=True): - """ - This function mirrors the functionality of the `load_config` function from the `keypoint_moseq` - package. Similarly, this function loads the `kpms_dj_config.yml` from the output directory. +def _dj_config_path(project_dir: Union[str, os.PathLike]) -> str: + return str(Path(project_dir) / DJ_CONFIG) + - Args: - output_dir (str): Directory containing the `kpms_dj_config.yml` that will be loaded. - check_if_valid (bool): default=True. Check if the config is valid using :py:func:`keypoint_moseq.io.check_config_validity` - build_indexes (bool): default=True. Add keys `"anterior_idxs"` and `"posterior_idxs"` to the config. 
Each maps to a jax array indexing the elements of `config["anterior_bodyparts"]` and `config["posterior_bodyparts"]` by their order in `config["use_bodyparts"]` +def _base_config_path(project_dir: Union[str, os.PathLike]) -> str: + return str(Path(project_dir) / BASE_CONFIG) - Returns: - kpms_dj_config (dict): configuration settings + +def _check_config_validity_like_upstream(config: Dict[str, Any]) -> bool: + """ + Minimal mirror of keypoint_moseq.io.check_config_validity logic that matters + for anatomy consistency (anterior/posterior must be subset of use_bodyparts). """ + errors = [] + for bp in config.get("anterior_bodyparts", []): + if bp not in config.get("use_bodyparts", []): + errors.append( + f"ACTION REQUIRED: `anterior_bodyparts` contains {bp} " + "which is not one of the options in `use_bodyparts`." + ) + for bp in config.get("posterior_bodyparts", []): + if bp not in config.get("use_bodyparts", []): + errors.append( + f"ACTION REQUIRED: `posterior_bodyparts` contains {bp} " + "which is not one of the options in `use_bodyparts`." + ) + if errors: + for e in errors: + print(e) + return False + return True + + +def dj_generate_config(project_dir: str, **kwargs) -> str: + """ + Generate or refresh `/kpms_dj_config.yml`. - from keypoint_moseq import check_config_validity + Behavior: + - If the DJ config doesn't exist, start from the **base** `/config.yml` + created by upstream `setup_project`, then overlay kwargs and write DJ config. + - If the DJ config exists, load it, overlay kwargs, and rewrite it. - config_path = os.path.join(output_dir, "kpms_dj_config.yml") + Returns the path to `kpms_dj_config.yml`. + """ + project_dir = str(project_dir) + base_cfg_path = _base_config_path(project_dir) + dj_cfg_path = _dj_config_path(project_dir) + + if os.path.exists(dj_cfg_path): + with open(dj_cfg_path, "r") as f: + cfg = yaml.safe_load(f) or {} + else: + if not os.path.exists(base_cfg_path): + raise FileNotFoundError( + f"Missing base config at {base_cfg_path}. 
Run upstream setup_project first."
+            )
+        with open(base_cfg_path, "r") as f:
+            cfg = yaml.safe_load(f) or {}
+
+    # Upstream uses shallow updates for top-level keys in generate_config.
+    # We follow that (simple `dict.update`); nested blocks can be passed explicitly.
+    cfg.update(kwargs)
+
+    # Upstream ensures skeleton exists; we do the same.
+    if "skeleton" not in cfg or cfg["skeleton"] is None:
+        cfg["skeleton"] = []
+
+    with open(dj_cfg_path, "w") as f:
+        yaml.safe_dump(cfg, f, sort_keys=False)
+    return dj_cfg_path
+
+
+def dj_load_config(
+    project_dir: str, check_if_valid: bool = True, build_indexes: bool = True
+) -> Dict[str, Any]:
+    """
+    Load `<project_dir>/kpms_dj_config.yml`.
 
-    with open(config_path, "r") as f:
-        kpms_dj_config = yaml.safe_load(f)
+    Mirrors keypoint_moseq.io.load_config behavior:
+    - check_if_valid -> anatomy subset checks
+    - build_indexes -> adds jax arrays 'anterior_idxs' and 'posterior_idxs'
+      indexing into 'use_bodyparts' by order.
+    """
+    dj_cfg_path = _dj_config_path(project_dir)
+    if not os.path.exists(dj_cfg_path):
+        raise FileNotFoundError(
+            f"Missing DJ config at {dj_cfg_path}. Create it with dj_generate_config()."
+        )
+
+    with open(dj_cfg_path, "r") as f:
+        cfg = yaml.safe_load(f) or {}
 
     if check_if_valid:
-        check_config_validity(kpms_dj_config)
+        _check_config_validity_like_upstream(
+            cfg
+        )  # mirrors upstream check_config_validity logic
 
     if build_indexes:
-        kpms_dj_config["anterior_idxs"] = jnp.array(
-            [
-                kpms_dj_config["use_bodyparts"].index(bp)
-                for bp in kpms_dj_config["anterior_bodyparts"]
-            ]
-        )
-        kpms_dj_config["posterior_idxs"] = jnp.array(
-            [
-                kpms_dj_config["use_bodyparts"].index(bp)
-                for bp in kpms_dj_config["posterior_bodyparts"]
-            ]
+        anterior = cfg.get("anterior_bodyparts", [])
+        posterior = cfg.get("posterior_bodyparts", [])
+        use_bps = cfg.get("use_bodyparts", [])
+        cfg["anterior_idxs"] = jnp.array(
+            [use_bps.index(bp) for bp in anterior]
+        )  # same indexing approach as upstream
+        cfg["posterior_idxs"] = jnp.array([use_bps.index(bp) for bp in posterior])
+
+    if "skeleton" not in cfg or cfg["skeleton"] is None:
+        cfg["skeleton"] = []
+
+    return cfg
+
+
+def dj_update_config(project_dir: str, **kwargs) -> Dict[str, Any]:
+    """
+    Update `kpms_dj_config.yml` with provided top-level kwargs (same pattern as
+    keypoint_moseq.io.update_config), then rewrite the file and return the dict.
+    """
+    dj_cfg_path = _dj_config_path(project_dir)
+    if not os.path.exists(dj_cfg_path):
+        raise FileNotFoundError(
+            f"Missing DJ config at {dj_cfg_path}. Create it with dj_generate_config()."
         )
 
-    if "skeleton" not in kpms_dj_config or kpms_dj_config["skeleton"] is None:
-        kpms_dj_config["skeleton"] = []
+    with open(dj_cfg_path, "r") as f:
+        cfg = yaml.safe_load(f) or {}
+
+    cfg.update(kwargs)
 
-    return kpms_dj_config
+    with open(dj_cfg_path, "w") as f:
+        yaml.safe_dump(cfg, f, sort_keys=False)
+    return cfg

From 558b6f5885baf55fcb0a4a34fab8b3b570b504b2 Mon Sep 17 00:00:00 2001
From: MilagrosMarin
Date: Tue, 26 Aug 2025 19:49:08 +0100
Subject: [PATCH 02/61] Update CHANGELOG and bump version

---
 CHANGELOG.md             | 7 +++++++
 element_moseq/version.py | 2 +-
 2 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 29c076d8..134138ab 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -3,6 +3,13 @@
 Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and
 [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention.
 
+## [0.4.0] - 2025-08-26 ++ Fix - Fix generation of `kpms_dj_config.yml` in `kpms_reader` to use `dj_load_config` and `dj_update_config` functions ++ Fix - Rename `PCAPrep` to `PreProcessing` ++ Feat - Add new attribute `outlier_scale_factor` in `PCATask` table ++ Feat - Add feature to remove outlier keypoints in `PreProcessing` table ++ Fix - `moseq_train` and `moseq_infer` to use `dj_load_config` and `dj_update_config` functions + ## [0.3.2] - 2025-08-25 + Feat - modernize packaging and environment management migrating from `setup.py` to `pyproject.toml`and `env.yml` + Fix - JAX compatibility issues diff --git a/element_moseq/version.py b/element_moseq/version.py index 931952fd..afa728c9 100644 --- a/element_moseq/version.py +++ b/element_moseq/version.py @@ -2,4 +2,4 @@ Package metadata """ -__version__ = "0.3.2" +__version__ = "0.4.0" From 442a718b3a5a084d215fb2438f514630370f4b63 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 26 Aug 2025 19:50:36 +0100 Subject: [PATCH 03/61] update CHANGELOG --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 134138ab..3fc60f73 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and + Feat - Add new attribute `outlier_scale_factor` in `PCATask` table + Feat - Add feature to remove outlier keypoints in `PreProcessing` table + Fix - `moseq_train` and `moseq_infer` to use `dj_load_config` and `dj_update_config` functions ++ Feat - Refactor `PreProcessing` table to use 3-part make function ## [0.3.2] - 2025-08-25 + Feat - modernize packaging and environment management migrating from `setup.py` to `pyproject.toml`and `env.yml` From fc68fa05a0d831deaba9c3a677ca05501a99aebb Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Thu, 4 Sep 2025 13:16:51 +0100 Subject: [PATCH 04/61] refactor:`Inference` paths --- element_moseq/moseq_infer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 
deletions(-) diff --git a/element_moseq/moseq_infer.py b/element_moseq/moseq_infer.py index 6e5e7567..abccb761 100644 --- a/element_moseq/moseq_infer.py +++ b/element_moseq/moseq_infer.py @@ -267,10 +267,10 @@ def make(self, key): ) keypointset_dir = find_full_path(kpms_root, keypointset_dir) - inference_output_dir = os.path.join(model_dir, inference_output_dir) + inference_output_dir = Path(os.path.join(model_dir, inference_output_dir)) - if not os.path.exists(inference_output_dir): - os.makedirs(model_dir / inference_output_dir) + if not inference_output_dir.exists(): + inference_output_dir.mkdir(parents=True, exist_ok=True) pca_path = model_dir.parent / "pca.p" if pca_path: @@ -370,8 +370,8 @@ def make(self, key): # load results results = load_results( - project_dir=Path(inference_output_dir).parent, - model_name=Path(inference_output_dir).parts[-1], + project_dir=inference_output_dir.parent, + model_name=inference_output_dir.parts[-1], ) # extract syllables from results From e10b7718a12eb4718de964682958853cd89de823 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Thu, 4 Sep 2025 13:36:25 +0100 Subject: [PATCH 05/61] feat(VideoFile): add new attributes for downstream statistics --- element_moseq/moseq_train.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index 353d8a44..3284adc6 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -168,6 +168,8 @@ class VideoFile(dj.Part): video_id : int # Unique ID for each video corresponding to each keypoint data file, relative to root data directory --- video_path : varchar(1000) # Filepath of each video from which the keypoints are derived, relative to root data directory + group_label='' : varchar(100) # Assign a group label (such as “mutant” or “wildtype”) to each recording. Relevant for performing group-wise comparisons. 
+ video_duration=0 : int # Duration of each video in minutes (if not provided, it will be automatically calculated in `PreProcessing`). """ From 2ded339fd449dd52e2cc1c1262585f6b300b6d43 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Thu, 4 Sep 2025 13:37:45 +0100 Subject: [PATCH 06/61] feat(PCATask): add new attribute and update docstrings --- element_moseq/moseq_train.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index 3284adc6..f932add4 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -200,19 +200,27 @@ class Bodyparts(dj.Manual): @schema class PCATask(dj.Manual): """ - Staging table to define the PCA task and its output directory. + Define the Principal Component Analysis (PCA) task for dimensionality reduction of keypoint data. + + This table defines the parameters for the PCA preprocessing step, which is a prerequisite for both + stages of the Keypoint-MoSeq training pipeline. The PCA step reduces the dimensionality of the + keypoint data by projecting it onto the principal components that capture the most variance in + the pose dynamics. This dimensionality reduction is essential for efficient model training and + helps identify the optimal latent dimension for the subsequent AR-HMM and Keypoint-SLDS model fitting. 
Attributes: Bodyparts (foreign key) : Unique ID for each `Bodyparts` key + outlier_scale_factor (int) : Scale factor for outlier detection in keypoint data (default: 6) kpms_project_output_dir (str) : Keypoint-MoSeq project output directory, relative to root data directory + task_mode (enum) : 'load' to load existing results, 'trigger' to compute new PCA """ definition = """ -> Bodyparts # Unique ID for each `Bodyparts` key --- - outlier_scale_factor=6 : int # Scale factor for outlier detection (default: 6) + outlier_scale_factor=6 : int # Scale factor for outlier detection in keypoint data (default: 6) kpms_project_output_dir='' : varchar(255) # Keypoint-MoSeq project output directory, relative to root data directory - task_mode='load' :enum('load','trigger') # Trigger or load the task + task_mode='load' :enum('load','trigger') # 'load' to load existing results, 'trigger' to compute new PCA """ From cbe49445b3a2a435876c886171f52582b13ff885 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Thu, 4 Sep 2025 13:38:53 +0100 Subject: [PATCH 07/61] docs(PreProcessing): update docstrings --- element_moseq/moseq_train.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index f932add4..0efc92ad 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -228,14 +228,21 @@ class PCATask(dj.Manual): @schema class PreProcessing(dj.Imported): """ - Table to set up the Keypoint-MoSeq project output directory (`kpms_project_output_dir`) , creating the default `config.yml` and updating it in a new `kpms_dj_config.yml`. + Preprocess keypoint data by cleaning outliers and setting up the Keypoint-MoSeq project configuration. + + This table handles the initial data preprocessing step that prepares keypoint data for the PCA and + subsequent model fitting stages. 
It performs outlier detection and removal using medoid distance + analysis, interpolates missing keypoints, and sets up the project configuration with video paths, + bodypart specifications, and frame rate information. This preprocessing also calculates and updates + video duration metadata in the KeypointSet.VideoFile table. This preprocessing is essential for + ensuring data quality and proper model initialization. Attributes: PCATask (foreign key) : Unique ID for each `PCATask` key. - coordinates (longblob) : Dictionary mapping filenames to keypoint coordinates as ndarrays of shape (n_frames, n_bodyparts, 2[or 3]). - confidences (longblob) : Dictionary mapping filenames to `likelihood` scores as ndarrays of shape (n_frames, n_bodyparts). + coordinates (longblob) : Dictionary mapping filenames to cleaned keypoint coordinates as ndarrays of shape (n_frames, n_bodyparts, 2[or 3]). + confidences (longblob) : Dictionary mapping filenames to updated likelihood scores as ndarrays of shape (n_frames, n_bodyparts). formatted_bodyparts (longblob) : List of bodypart names. The order of the names matches the order of the bodyparts in `coordinates` and `confidences`. - average_frame_rate (float) : Average frame rate of the videos for model training. + average_frame_rate (float) : Average frame rate of the videos for model training (used for kappa calculation). frame_rates (longblob) : List of the frame rates of the videos for model training. """ @@ -245,7 +252,7 @@ class PreProcessing(dj.Imported): coordinates : longblob # Dictionary mapping filenames to keypoint coordinates as ndarrays of shape (n_frames, n_bodyparts, 2[or 3]) confidences : longblob # Dictionary mapping filenames to `likelihood` scores as ndarrays of shape (n_frames, n_bodyparts) formatted_bodyparts : longblob # List of bodypart names. The order of the names matches the order of the bodyparts in `coordinates` and `confidences`. 
- average_frame_rate : float # Average frame rate of the videos for model training + average_frame_rate : float # Average frame rate of the videos for model training (used for kappa calculation). frame_rates : longblob # List of the frame rates of the videos for model training """ @@ -322,7 +329,7 @@ def make_compute( task_mode, outlier_scale_factor, ): - """Compute the outlier keypoints and interpolate them. + """Compute the outlier keypoints, interpolate them, and extract video metadata. Args: key (dict): Primary key from the `PCATask` table. @@ -339,7 +346,8 @@ def make_compute( 2. Find outliers using medoid distance analysis 3. Interpolate keypoints to fix outliers 4. Update confidences for outlier points - 5. Store results + 5. Calculate video frame rates and update video durations in KeypointSet.VideoFile table + 6. Store results """ from keypoint_moseq import ( From aa28e67e0bc4e35c2435b688b9c52fb2d875e0ac Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Thu, 4 Sep 2025 13:40:08 +0100 Subject: [PATCH 08/61] feat(PreProcessing): add computation of FPS and video duration --- element_moseq/moseq_train.py | 39 ++++++++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index 0efc92ad..3b688d4d 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -407,14 +407,22 @@ def make_compute( filepath_pattern=kpset_dir, format=pose_estimation_method ) - # compute FPS + # compute FPS and video duration frame_rate_list = [] - for fp, _ in zip(video_paths, video_ids): + video_duration_list = [] + for fp, video_id in zip(video_paths, video_ids): video_path = (find_full_path(get_kpms_root_data_dir(), fp)).as_posix() cap = cv2.VideoCapture(video_path) fps = float(cap.get(cv2.CAP_PROP_FPS)) + frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) cap.release() + + # Calculate duration in minutes + duration_minutes = (frame_count / fps) / 60.0 + 
frame_rate_list.append(fps) + video_duration_list.append((video_id, int(duration_minutes))) + average_frame_rate = float(np.mean(frame_rate_list)) # Generate a copy of config.yml with the generated/updated info after it is known @@ -477,6 +485,7 @@ def make_compute( formatted_bodyparts, average_frame_rate, frame_rate_list, + video_duration_list, ) def make_insert( @@ -487,7 +496,9 @@ def make_insert( formatted_bodyparts, average_frame_rate, frame_rate_list, + video_duration_list, ): + # Insert the main preprocessing results self.insert1( dict( **key, @@ -499,6 +510,30 @@ def make_insert( ) ) + # Update video durations in KeypointSet.VideoFile table + for video_id, duration_minutes in video_duration_list: + try: + # Check if the video record exists + video_key = {"kpset_id": key["kpset_id"], "video_id": video_id} + if KeypointSet.VideoFile & video_key: + # For Manual tables, we need to get the existing record and update it + existing_record = (KeypointSet.VideoFile & video_key).fetch1() + + # Create updated record with new video_duration + updated_record = dict(existing_record) + updated_record["video_duration"] = duration_minutes + + # Delete the old record and insert the updated one + KeypointSet.VideoFile.update1(updated_record) + else: + logger.warning(f"Video record not found for video_id {video_id}") + + except Exception as e: + logger.warning( + f"Warning: Could not update video duration for video_id {video_id}: {e}" + ) + # Continue processing even if update fails + @schema class PCAFit(dj.Computed): From 398281603e7d4888b99423011c64bbc26e64b926 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Thu, 4 Sep 2025 13:40:45 +0100 Subject: [PATCH 09/61] docs(PCAFit): update docstring --- element_moseq/moseq_train.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index 3b688d4d..cf01e46e 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -537,7 
+537,13 @@ def make_insert( @schema class PCAFit(dj.Computed): - """Fit PCA model. + """Fit Principal Component Analysis (PCA) model for dimensionality reduction of keypoint data. + + This table computes the PCA model that reduces the dimensionality of the keypoint data by projecting + it onto the principal components that capture the most variance in the pose dynamics. The fitted PCA + model is essential for both stages of the Keypoint-MoSeq training pipeline, as it provides the + low-dimensional representation of pose trajectories that will be used in the AR-HMM and Keypoint-SLDS + model fitting. Attributes: PreProcessing (foreign key) : `PreProcessing` Key. From 57d51171703ca28d5560fb4034241d24173d918e Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Thu, 4 Sep 2025 13:41:17 +0100 Subject: [PATCH 10/61] docs(LatentDimension): udpate docstring --- element_moseq/moseq_train.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index cf01e46e..4e18d6bd 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -602,10 +602,13 @@ def make(self, key): @schema class LatentDimension(dj.Imported): """ - Determine the latent dimension as part of the autoregressive hyperparameters (`ar_hypparams`) for the model fitting. - The objective of the analysis is to inform the user about the number of principal components needed to explain a - 90% variance threshold. Subsequently, the decision on how many components to utilize for the model fitting is left - to the user. + Determine the optimal latent dimension for model fitting based on variance explained by PCA components. + + This table analyzes the fitted PCA model to determine the number of principal components needed to + explain a specified variance threshold (default: 90%). 
This analysis helps users make informed decisions + about the latent dimension parameter for both stages of the Keypoint-MoSeq training pipeline. The + recommended approach is to use the number of components that explain 90% of the variance, or a maximum + of 10 dimensions, whichever is lower. Attributes: PCAFit (foreign key) : `PCAFit` Key. From 3e0c60ad78be5473d0ed5470cde777e24947bc27 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Thu, 4 Sep 2025 13:42:07 +0100 Subject: [PATCH 11/61] docs(PreFitTask, PreFit): update docstrings --- element_moseq/moseq_train.py | 26 +++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index 4e18d6bd..fecdb936 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -690,21 +690,27 @@ def make(self, key): @schema class PreFitTask(dj.Manual): - """Insert the parameters for the model (AR-HMM) pre-fitting. + """Define parameters for Stage 1: Auto-Regressive Hidden Markov Model (AR-HMM) pre-fitting. + + This table defines the parameters for the first stage of the two-stage Keypoint-MoSeq training + approach. The pre-fitting stage focuses on learning temporal behavioral dynamics and discrete + syllable states without incorporating spatial pose dynamics. This stage is designed for rapid + hyperparameter exploration, particularly for finding optimal kappa values that yield desired + syllable durations. Attributes: PCAFit (foreign key) : `PCAFit` task. pre_latent_dim (int) : Latent dimension to use for the model pre-fitting. - pre_kappa (int) : Kappa value to use for the model pre-fitting. - pre_num_iterations (int) : Number of Gibbs sampling iterations to run in the model pre-fitting. + pre_kappa (int) : Kappa value to use for the model pre-fitting (controls syllable duration). + pre_num_iterations (int) : Number of Gibbs sampling iterations to run in the model pre-fitting (typically 10-50). 
pre_fit_desc(varchar) : User-defined description of the pre-fitting task. """ definition = """ -> PCAFit # `PCAFit` Key - pre_latent_dim : int # Latent dimension to use for the model pre-fitting - pre_kappa : int # Kappa value to use for the model pre-fitting - pre_num_iterations : int # Number of Gibbs sampling iterations to run in the model pre-fitting + pre_latent_dim : int # Latent dimension to use for the model pre-fitting. + pre_kappa : int # Kappa value to use for the model pre-fitting (controls syllable duration). + pre_num_iterations : int # Number of Gibbs sampling iterations to run in the model pre-fitting (typically 10-50). --- model_name : varchar(100) # Name of the model to be loaded if `task_mode='load'` task_mode='load' :enum('load','trigger')# 'load': load computed analysis results, 'trigger': trigger computation @@ -714,7 +720,13 @@ class PreFitTask(dj.Manual): @schema class PreFit(dj.Computed): - """Fit AR-HMM model. + """Stage 1: Fit Auto-Regressive Hidden Markov Model (AR-HMM) for initial behavioral syllable discovery. + + This is the first stage of the two-stage Keypoint-MoSeq training approach. The PreFit step focuses + exclusively on learning the temporal dynamics and discrete behavioral states (syllables) without + incorporating spatial pose dynamics. This stage is computationally efficient and allows for rapid + exploration of hyperparameters (particularly kappa for syllable duration) before the expensive + full model fitting. Attributes: PreFitTask (foreign key) : `PreFitTask` Key. 
From 5bbdd4f3913ac898b2c5f4512356bce9c75baff9 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Thu, 4 Sep 2025 13:44:13 +0100 Subject: [PATCH 12/61] feat(PreFit): add `estimate_sigmasq_loc` from the latest KPMS version --- element_moseq/moseq_train.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index fecdb936..891b3ec6 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -786,6 +786,7 @@ def make(self, key): "model_name", ) if task_mode == "trigger": + from keypoint_moseq import estimate_sigmasq_loc # Update the existing kpms_dj_config.yml with new latent_dim and kappa values kpms_reader.dj_update_config( @@ -812,6 +813,17 @@ def make(self, key): coordinates=coordinates, confidences=confidences, **kpms_dj_config ) + kpms_reader.dj_update_config( + kpms_project_output_dir, + sigmasq_loc=estimate_sigmasq_loc( + data["Y"], data["mask"], filter_size=int(kpms_dj_config["fps"]) + ), + ) + + kpms_dj_config = kpms_reader.dj_load_config( + project_dir=kpms_project_output_dir + ) + model = init_model(data=data, metadata=metadata, pca=pca, **kpms_dj_config) model = update_hypparams( From 185e78d1ff1dcd69a3e418954a9318c5e66aba1c Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Thu, 4 Sep 2025 13:45:17 +0100 Subject: [PATCH 13/61] fix(PreFit): change folder name to match primary attributes instead of datetime --- element_moseq/moseq_train.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index 891b3ec6..0cde990b 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -830,14 +830,23 @@ def make(self, key): model, kappa=float(pre_kappa), latent_dim=int(pre_latent_dim) ) + dj_model_name = { + "latent_dim": int(pre_latent_dim), + "pre_kappa": float(pre_kappa), + "pre_num_iterations": int(pre_num_iterations), + } + start_time = datetime.now(timezone.utc) model, model_name = 
fit_model( model=model, + model_name=dj_model_name, data=data, metadata=metadata, project_dir=kpms_project_output_dir.as_posix(), ar_only=True, num_iters=pre_num_iterations, + generate_progress_plots=True, # saved to {project_dir}/{model_name}/plots/ + save_every_n_iters=25, ) end_time = datetime.now(timezone.utc) From 5a0c254f8c6d65cbb432703e07c70754dce314fc Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Thu, 4 Sep 2025 13:45:55 +0100 Subject: [PATCH 14/61] docs(FullFitTask): update docstring --- element_moseq/moseq_train.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index 0cde990b..985db4ef 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -868,23 +868,29 @@ def make(self, key): @schema class FullFitTask(dj.Manual): - """Insert the parameters for the full (Keypoint-SLDS model) fitting. - The full model will generally require a lower value of kappa to yield the same target syllable durations. + """Define parameters for Stage 2: Full Keypoint Switching Linear Dynamical System (Keypoint-SLDS) model fitting. + + This table defines the parameters for the second stage of the two-stage Keypoint-MoSeq training approach. + The full fitting stage refines the model by incorporating all parameters including spatial pose dynamics, + centroid, heading, noise estimates, and continuous latent states. This stage builds upon the initial + behavioral structure discovered in the pre-fitting stage to create a complete behavioral model. + + Note: The full model will generally require a lower value of kappa to yield the same target syllable + durations compared to the pre-fitting stage, as the additional spatial dynamics provide more information. Attributes: PCAFit (foreign key) : `PCAFit` Key. full_latent_dim (int) : Latent dimension to use for the model full fitting. - full_kappa (int) : Kappa value to use for the model full fitting. 
- full_num_iterations (int) : Number of Gibbs sampling iterations to run in the model full fitting. + full_kappa (int) : Kappa value to use for the model full fitting (typically lower than pre-fit kappa). + full_num_iterations (int) : Number of Gibbs sampling iterations to run in the model full fitting (typically 200-500). full_fit_desc(varchar) : User-defined description of the model full fitting task. - """ definition = """ -> PCAFit # `PCAFit` Key full_latent_dim : int # Latent dimension to use for the model full fitting - full_kappa : int # Kappa value to use for the model full fitting - full_num_iterations : int # Number of Gibbs sampling iterations to run in the model full fitting + full_kappa : int # Kappa value to use for the model full fitting (typically lower than pre-fit kappa). + full_num_iterations : int # Number of Gibbs sampling iterations to run in the model full fitting (typically 200-500). --- model_name : varchar(100) # Name of the model to be loaded if `task_mode='load'` task_mode='load' :enum('load','trigger')# Trigger or load the task From 742249dbbbf88d545951fd78d8ce82ee4d4a472c Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Thu, 4 Sep 2025 13:46:28 +0100 Subject: [PATCH 15/61] docs(FullFit): udpate docstrings --- element_moseq/moseq_train.py | 45 ++++++++++++++++++++---------------- 1 file changed, 25 insertions(+), 20 deletions(-) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index 985db4ef..891ed9bd 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -900,7 +900,13 @@ class FullFitTask(dj.Manual): @schema class FullFit(dj.Computed): - """Fit the full (Keypoint-SLDS) model. + """Stage 2: Fit the complete Keypoint Switching Linear Dynamical System (Keypoint-SLDS) model. + + This is the second stage of the two-stage Keypoint-MoSeq training approach. 
The FullFit step refines + the model by incorporating all parameters including spatial pose dynamics, centroid, heading, noise + estimates, and continuous latent states. This stage builds upon the initial behavioral exploration + discovered in the pre-fitting stage to create a complete behavioral model with both temporal and + spatial dynamics. Attributes: FullFitTask (foreign key) : `FullFitTask` Key. @@ -919,25 +925,24 @@ def make(self, key): """ Make function to fit the full (keypoint-SLDS) model - Args: - key (dict): dictionary with the `FullFitTask` Key. - - Raises: - - High-level Logic: - 1. Fetch the `kpms_project_output_dir` and define the full path. - 2. Fetch the model parameters from the `FullFitTask` table. - 2. Update the `dj_config.yml` with the selected latent dimension and kappa for the full-fitting. - 3. Initialize and fit the full model in a new `model_name` directory. - 4. Load the pca model from file in the `kpms_project_output_dir`. - 5. Fetch the `coordinates` and `confidences` scores to format the data for the model initialization. - 6. Initialize the model that create a `model` dict containing states, parameters, hyperparameters, noise prior, and random seed. - 7. Update the model dict with the selected kappa for the Keypoint-SLDS fitting. - 8. Fit the Keypoint-SLDS model using the `full_num_iterations` and create a subdirectory in `kpms_project_output_dir` with the model's latest checkpoint file. - 8. Reindex syllable labels by their frequency in the most recent model snapshot in the checkpoint file. \ - This function permutes the states and parameters of a saved checkpoint so that syllables are labeled \ - in order of frequency (i.e. so that 0 is the most frequent, 1 is the second most, and so on). - 8. Calculate the duration of the model fitting computation and insert it in the `PreFit` table. + Args: + key (dict): dictionary with the `FullFitTask` Key. + + Raises: + + High-level Logic: + 1. 
Fetch the `kpms_project_output_dir` and define the full path. + 2. Fetch the model parameters from the `FullFitTask` table. + 3. Update the `dj_config.yml` with the selected latent dimension and kappa for the full-fitting. + 4. Load the pca model from file in the `kpms_project_output_dir`. + 5. Fetch the `coordinates` and `confidences` scores to format the data for the model initialization. + 6. Initialize the model that create a `model` dict containing states, parameters, hyperparameters, noise prior, and random seed. + 7. Update the model dict with the selected kappa for the Keypoint-SLDS fitting. + 8. Fit the Keypoint-SLDS model using the `full_num_iterations` and create a subdirectory in `kpms_project_output_dir` with the model's latest checkpoint file. + 9. Reindex syllable labels by their frequency in the most recent model snapshot in the checkpoint file. \ + This function permutes the states and parameters of a saved checkpoint so that syllables are labeled \ + in order of frequency (i.e. so that 0 is the most frequent, 1 is the second most, and so on). + 10. Calculate the duration of the model fitting computation and insert it in the `FullFit` table. """ from keypoint_moseq import ( fit_model, From 76015c3be36f099e3b8343ddf4b7e59a3b326a40 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Thu, 4 Sep 2025 13:48:43 +0100 Subject: [PATCH 16/61] feat(FullFit): add `estimate_sigmasq_loc` from the latest version of KPMS --- element_moseq/moseq_train.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index 891ed9bd..95561923 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -945,6 +945,7 @@ def make(self, key): 10. Calculate the duration of the model fitting computation and insert it in the `FullFit` table. 
""" from keypoint_moseq import ( + estimate_sigmasq_loc, fit_model, format_data, init_model, @@ -969,12 +970,14 @@ def make(self, key): ) if task_mode == "trigger": kpms_reader.dj_update_config( - kpms_project_output_dir, + project_dir=kpms_project_output_dir, latent_dim=int(full_latent_dim), kappa=float(full_kappa), ) - kpms_dj_config = kpms_reader.dj_load_config(kpms_project_output_dir) + kpms_dj_config = kpms_reader.dj_load_config( + project_dir=kpms_project_output_dir + ) pca_path = kpms_project_output_dir / "pca.p" if pca_path.exists(): @@ -990,6 +993,17 @@ def make(self, key): data, metadata = format_data( coordinates=coordinates, confidences=confidences, **kpms_dj_config ) + kpms_reader.dj_update_config( + project_dir=kpms_project_output_dir, + sigmasq_loc=estimate_sigmasq_loc( + data["Y"], data["mask"], filter_size=int(kpms_dj_config["fps"]) + ), + ) + + kpms_dj_config = kpms_reader.dj_load_config( + project_dir=kpms_project_output_dir + ) + model = init_model(data=data, metadata=metadata, pca=pca, **kpms_dj_config) model = update_hypparams( model, kappa=float(full_kappa), latent_dim=int(full_latent_dim) From b343e1c3004f5e53032ade39c5eec7e2a81c7a0e Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Thu, 4 Sep 2025 13:49:40 +0100 Subject: [PATCH 17/61] fix(FullFit): update folder name to primary attributes instead of datetime --- element_moseq/moseq_train.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index 95561923..e6ea445f 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -1009,9 +1009,16 @@ def make(self, key): model, kappa=float(full_kappa), latent_dim=int(full_latent_dim) ) + dj_model_name = { + "latent_dim": int(full_latent_dim), + "full_kappa": float(full_kappa), + "full_num_iterations": int(full_num_iterations), + } + start_time = datetime.now(timezone.utc) model, model_name = fit_model( model=model, + 
model_name=dj_model_name, data=data, metadata=metadata, project_dir=kpms_project_output_dir.as_posix(), @@ -1022,8 +1029,12 @@ def make(self, key): duration_seconds = (end_time - start_time).total_seconds() reindex_syllables_in_checkpoint( - kpms_project_output_dir.as_posix(), Path(model_name).parts[-1] + project_dir=kpms_project_output_dir.as_posix(), + model_name=Path(model_name).parts[-1], ) + + # Save of results will be applied during the Inference + else: duration_seconds = None From b2d20138b67a51f1eee42cdbb0112316cdc30ff1 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Thu, 4 Sep 2025 13:50:39 +0100 Subject: [PATCH 18/61] docs(SelectedFullFit): update docstring --- element_moseq/moseq_train.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index e6ea445f..51c933be 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -1052,7 +1052,12 @@ def make(self, key): @schema class SelectedFullFit(dj.Manual): - """Selected FullFit model for use in the inference pipeline""" + """Register selected FullFit models for use in the inference pipeline. + + This table allows users to select and register specific FullFit models (from Stage 2 of the training + pipeline) for use in downstream inference tasks. Users can provide descriptive names and descriptions + for their trained models to facilitate model management and selection for behavioral analysis on new data. 
+ """ definition = """ -> FullFit From bbd3a015fbe399040703f0197b65a5491879b114 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Thu, 4 Sep 2025 14:51:24 +0100 Subject: [PATCH 19/61] fix(devcontainer): update Dockerfile, update python version for KPMS >=3.10 --- .devcontainer/Dockerfile | 29 ++++++++++++----------------- 1 file changed, 12 insertions(+), 17 deletions(-) diff --git a/.devcontainer/Dockerfile b/.devcontainer/Dockerfile index e8b61591..4ab1cae2 100644 --- a/.devcontainer/Dockerfile +++ b/.devcontainer/Dockerfile @@ -1,10 +1,13 @@ -FROM python:3.9-slim@sha256:5f0192a4f58a6ce99f732fe05e3b3d00f12ae62e183886bca3ebe3d202686c7f +ARG PY_VER=3.11 +ARG DISTRO=bullseye +FROM mcr.microsoft.com/devcontainers/python:${PY_VER}-${DISTRO} -ENV PATH /usr/local/bin:$PATH -ENV PYTHON_VERSION 3.9.17 +# Avoid warnings by switching to noninteractive +ENV DEBIAN_FRONTEND=noninteractive + +USER root RUN \ - adduser --system --disabled-password --shell /bin/bash vscode && \ # install docker apt-get update && \ apt-get install ca-certificates curl gnupg lsb-release -y && \ @@ -31,27 +34,19 @@ COPY ./ /tmp/element-moseq/ RUN \ # pipeline dependencies - apt-get update && \ - apt-get install -y gcc ffmpeg graphviz && \ - pip install ipywidgets && \ - pip install --no-cache-dir -e /tmp/element-moseq[kpms,elements,tests] && \ + apt-get install gcc g++ ffmpeg libsm6 libxext6 -y && \ + pip install --no-cache-dir -e /tmp/element-moseq[elements,tests] && \ # clean up - rm -rf /tmp/element-moseq/ && \ + rm -rf /tmp/element-moseq && \ apt-get clean -# Install CPU version for KPMS -RUN pip install "jax[cpu]==0.3.22" -f https://storage.googleapis.com/jax-releases/jax_releases.html - ENV DJ_HOST fakeservices.datajoint.io ENV DJ_USER root ENV DJ_PASS simple -ENV DATA_MOUNTPOINT /workspaces/element-moseq/example_data -ENV KPMS_ROOT_DATA_DIR $DATA_MOUNTPOINT/inbox -ENV KPMS_PROCESSED_DATA_DIR $DATA_MOUNTPOINT/outbox +ENV KPMS_ROOT_DATA_DIR /workspaces/element-moseq/example_data ENV 
DATABASE_PREFIX neuro_ USER vscode -CMD bash -c "sudo rm /var/run/docker.pid; sudo dockerd" -ENV LD_LIBRARY_PATH="/lib:/opt/conda/lib" +CMD bash -c "sudo rm /var/run/docker.pid; sudo dockerd" From 9bcf9a3ec1e5fe594d31d6eba6276e0851bd3342 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Thu, 4 Sep 2025 14:52:03 +0100 Subject: [PATCH 20/61] feat(pyproject): update version of KPMS from 0.4.8 to the latest version from source --- pyproject.toml | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 115b9f1a..fc3ea764 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,12 +24,7 @@ dependencies = [ "ipykernel>=6.0.1", "ipywidgets", "opencv-python", - "keypoint-moseq==0.4.8", - "dynamax==0.1.4", - "jax==0.4.13", - "jaxlib==0.4.13", - "jaxtyping==0.2.14", - "jupyter_bokeh", + "keypoint-moseq @ git+https://github.com/dattalab/keypoint-moseq/", ] [project.optional-dependencies] From e8e51f1e332c87ffca9155d885ee84fd466020eb Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Thu, 4 Sep 2025 15:20:40 +0100 Subject: [PATCH 21/61] update CHANGELOG --- CHANGELOG.md | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3fc60f73..28874b17 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,13 +3,23 @@ Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention. 
-## [0.4.0] - 2025-08-26 -+ Fix - Fix generation of `kpms_dj_config.yml` in `kpms_reader` to use `dj_load_config` and `dj_update_config` functions -+ Fix - Rename `PCAPrep` to `PreProcessing` +## [0.4.0] - 2025-09-04 + Feat - Add new attribute `outlier_scale_factor` in `PCATask` table + Feat - Add feature to remove outlier keypoints in `PreProcessing` table -+ Fix - `moseq_train` and `moseq_infer` to use `dj_load_config` and `dj_update_config` functions + Feat - Refactor `PreProcessing` table to use 3-part make function ++ Feat - Add new attribute `registered_model_name` and `registered_model_desc` in `SelectedFullFit` table ++ Feat - Add new attribute `group_label` in `VideoFile` table for downstream statistical analysis ++ Feat - Add new attributes `video_duration` in `VideoFile` table and `average_frame_rate` in `PreProcessing` table ++ Feat - Add new computation for FPS and video duration in `PreProcessing` to calculate `video_duration` and `average_frame_rate` ++ Feat - Add `sigmasq_loc` feature in `PreFit` and `FullFit` to automatically estimate sigmasq_loc (prior controlling the centroid movement across frames) ++ Fix - Fix generation of `kpms_dj_config.yml` in `kpms_reader` to use `dj_load_config` and `dj_update_config` functions ++ Fix - `moseq_train` and `moseq_infer` to use `dj_load_config` and `dj_update_config` functions ++ Fix - Rename `PCAPrep` to `PreProcessing` ++ Fix - Remove JAX dependencies from `pyproject.toml` ++ Fix - Update dockerfile to use Python 3.11 ++ Add - Update docstrings + + ## [0.3.2] - 2025-08-25 + Feat - modernize packaging and environment management migrating from `setup.py` to `pyproject.toml`and `env.yml` From 317c51a04ba2af1dc9a754ee968b97244acf7dab Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Thu, 4 Sep 2025 15:20:57 +0100 Subject: [PATCH 22/61] update CHANGELOG --- CHANGELOG.md | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 28874b17..1847309f 100644 --- 
a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,9 +17,10 @@ Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and + Fix - Rename `PCAPrep` to `PreProcessing` + Fix - Remove JAX dependencies from `pyproject.toml` + Fix - Update dockerfile to use Python 3.11 ++ Fix - update folder name to primary attributes instead of datetime in `PreFit` and `FullFit` ++ Fix - Refactor `moseq_infer` paths + Add - Update docstrings - - ++ Add - Update pre-commit file ## [0.3.2] - 2025-08-25 + Feat - modernize packaging and environment management migrating from `setup.py` to `pyproject.toml`and `env.yml` From 911546ffabf274076ba6907e41a683bccfe83a3e Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Thu, 4 Sep 2025 15:33:54 +0100 Subject: [PATCH 23/61] fix(pyproject): update version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index fc3ea764..7d5cb059 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "element-moseq" -version = "0.3.1" +version = "0.4.0" description = "Keypoint-MoSeq DataJoint Element" readme = "README.md" license = {text = "MIT"} From 3f782b7079135589d624988302c138ffee0985ff Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Thu, 4 Sep 2025 15:57:34 +0100 Subject: [PATCH 24/61] Minor update to attribute defaults --- element_moseq/moseq_train.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index 51c933be..c9ca607c 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -712,7 +712,7 @@ class PreFitTask(dj.Manual): pre_kappa : int # Kappa value to use for the model pre-fitting (controls syllable duration). pre_num_iterations : int # Number of Gibbs sampling iterations to run in the model pre-fitting (typically 10-50). 
--- - model_name : varchar(100) # Name of the model to be loaded if `task_mode='load'` + model_name='' : varchar(100) # Name of the model to be loaded if `task_mode='load'` task_mode='load' :enum('load','trigger')# 'load': load computed analysis results, 'trigger': trigger computation pre_fit_desc='' : varchar(1000) # User-defined description of the pre-fitting task """ @@ -892,7 +892,7 @@ class FullFitTask(dj.Manual): full_kappa : int # Kappa value to use for the model full fitting (typically lower than pre-fit kappa). full_num_iterations : int # Number of Gibbs sampling iterations to run in the model full fitting (typically 200-500). --- - model_name : varchar(100) # Name of the model to be loaded if `task_mode='load'` + model_name='' : varchar(100) # Name of the model to be loaded if `task_mode='load'` task_mode='load' :enum('load','trigger')# Trigger or load the task full_fit_desc='' : varchar(1000) # User-defined description of the model full fitting task """ @@ -917,7 +917,7 @@ class FullFit(dj.Computed): definition = """ -> FullFitTask # `FullFitTask` Key --- - model_name : varchar(100) # Name of the model as "kpms_project_output_dir/model_name" + model_name='' : varchar(100) # Name of the model as "kpms_project_output_dir/model_name" full_fit_duration=NULL : float # Time duration (seconds) of the full fitting computation """ From 0684a290327d2ec1e40eb5503e1bb880e6d8b7a4 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Thu, 4 Sep 2025 18:04:28 +0100 Subject: [PATCH 25/61] fix model_name_str --- element_moseq/moseq_train.py | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index c9ca607c..3d596620 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -830,16 +830,12 @@ def make(self, key): model, kappa=float(pre_kappa), latent_dim=int(pre_latent_dim) ) - dj_model_name = { - "latent_dim": int(pre_latent_dim), - "pre_kappa": 
float(pre_kappa), - "pre_num_iterations": int(pre_num_iterations), - } + model_name_str = f"latent_dim_{int(pre_latent_dim)}_kappa_{float(pre_kappa)}_iters_{int(pre_num_iterations)}" start_time = datetime.now(timezone.utc) model, model_name = fit_model( model=model, - model_name=dj_model_name, + model_name=model_name_str, data=data, metadata=metadata, project_dir=kpms_project_output_dir.as_posix(), @@ -1009,16 +1005,12 @@ def make(self, key): model, kappa=float(full_kappa), latent_dim=int(full_latent_dim) ) - dj_model_name = { - "latent_dim": int(full_latent_dim), - "full_kappa": float(full_kappa), - "full_num_iterations": int(full_num_iterations), - } + model_name_str = f"latent_dim_{int(full_latent_dim)}_kappa_{float(full_kappa)}_iters_{int(full_num_iterations)}" start_time = datetime.now(timezone.utc) model, model_name = fit_model( model=model, - model_name=dj_model_name, + model_name=model_name_str, data=data, metadata=metadata, project_dir=kpms_project_output_dir.as_posix(), From 0650b364a86f2954fdc5eaf4068b9c0ec9a9de10 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 9 Sep 2025 13:46:38 +0100 Subject: [PATCH 26/61] review: update filename to `conda_env.yml` --- env.yml => conda_env.yml | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename env.yml => conda_env.yml (100%) diff --git a/env.yml b/conda_env.yml similarity index 100% rename from env.yml rename to conda_env.yml From cc504922d442ebc39b8ccb073b854f1f2b30ca1c Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 9 Sep 2025 13:47:28 +0100 Subject: [PATCH 27/61] feat: add `report` schema --- element_moseq/report.py | 315 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 315 insertions(+) create mode 100644 element_moseq/report.py diff --git a/element_moseq/report.py b/element_moseq/report.py new file mode 100644 index 00000000..b710abe9 --- /dev/null +++ b/element_moseq/report.py @@ -0,0 +1,315 @@ +import importlib +import inspect +import os +import pathlib +import tempfile +from 
pathlib import Path + +import datajoint as dj +import matplotlib.pyplot as plt +import numpy as np +from element_interface.utils import find_full_path + +from . import moseq_infer, moseq_train +from .plotting import viz_utils +from .readers import kpms_reader + +schema = dj.schema() +_linking_module = None +logger = dj.logger + + +def activate( + report_schema_name: str, + *, + create_schema: bool = True, + create_tables: bool = True, + linking_module: str = None, +): + """Activate this schema. + + Args: + report_schema_name (str): Schema name on the database server to activate the `moseq_infer` schema. + create_schema (bool): When True (default), create schema in the database if it + does not yet exist. + create_tables (bool): When True (default), create schema tables in the database + if they do not yet exist. + linking_module (str): A module (or name) containing the required dependencies. + """ + + if isinstance(linking_module, str): + linking_module = importlib.import_module(linking_module) + assert inspect.ismodule( + linking_module + ), "The argument 'dependency' must be a module's name or a module" + assert hasattr( + linking_module, "get_kpms_root_data_dir" + ), "The linking module must specify a lookup function for a root data directory" + + global _linking_module + _linking_module = linking_module + + # activate + schema.activate( + report_schema_name, + create_schema=create_schema, + create_tables=create_tables, + add_objects=_linking_module.__dict__, + ) + + +# ----------------------------- Table declarations ---------------------- + + +@schema +class PreProcessingReport(dj.Imported): + """Store the outlier keypoints plots that are generated in outbox by `moseq_train.PreProcessing`""" + + definition = """ + -> moseq_train.PreProcessing + video_id: int # ID of the matching video file + --- + recording_name: varchar(255) # Name of the recording + outlier_plot: attach # A plot of the outlier keypoints + """ + + def make(self, key): + # Resolve project dir 
+ project_rel = (moseq_train.PCATask & key).fetch1("kpms_project_output_dir") + kpms_project_output_dir = ( + Path(moseq_train.get_kpms_processed_data_dir()) / project_rel + ) + + # Fetch video table info + video_paths, video_ids = (moseq_train.KeypointSet.VideoFile & key).fetch( + "video_path", "video_id" + ) + + # Get recording names from PreProcessing dict keys + coords = (moseq_train.PreProcessing & key).fetch1("coordinates") + recording_names = list(coords.keys()) + + # Build mapping recording_name -> video_id + rec2vid = viz_utils.build_recording_to_video_id( + recording_names=recording_names, + video_paths=list(video_paths), + video_ids=list(video_ids), + ) + + # Insert one row per recording that matched a video_id + for rec in recording_names: + vid = rec2vid.get(rec) + if vid is None: + dj.logger.warning( + f"[PreProcessingReport] No video_id match for recording '{rec}'. Skipping." + ) + continue + + plot_path = ( + kpms_project_output_dir + / "quality_assurance" + / "plots" + / "keypoint_distance_outliers" + / f"{rec}.png" + ) + + if not plot_path.exists(): + dj.logger.warning( + f"[PreProcessingReport] Outlier plot not found at {plot_path}. Skipping." + ) + continue + + self.insert1( + { + **key, + "video_id": int(vid), + "recording_name": rec, + "outlier_plot": plot_path.as_posix(), + } + ) + + +@schema +class PCAReport(dj.Computed): + """ + Plots the principal components (PCs) from a PCAFit. + """ + + definition = """ + -> moseq_train.LatentDimension + --- + scree_plot: attach # A cumulative scree plot. + pcs_plot: attach # A visualization of each Principal Component (PC). 
+ """ + + def make(self, key): + # Generate and store plots for the user to choose the latent dimensions in the next step + from keypoint_moseq import load_pca + + kpms_project_output_dir = (moseq_train.PCATask & key).fetch1( + "kpms_project_output_dir" + ) + kpms_project_output_dir = ( + moseq_train.get_kpms_processed_data_dir() / kpms_project_output_dir + ) + kpms_dj_config = kpms_reader.dj_load_config(project_dir=kpms_project_output_dir) + + pca = load_pca(kpms_project_output_dir.as_posix()) + + # Modified version of plot_scree from keypoint_moseq + scree_fig = plt.figure() + num_pcs = len(pca.components_) + plt.plot(np.arange(num_pcs) + 1, np.cumsum(pca.explained_variance_ratio_)) + plt.xlabel("PCs") + plt.ylabel("Explained variance") + plt.gcf().set_size_inches((2.5, 2)) + plt.grid() + plt.tight_layout() + fname = f"{key['kpset_id']}_{key['bodyparts_id']}" + + # Modified version ofplot_pcs from keypoint_moseq to visualize components of PCs + pcs_fig = viz_utils.plot_pcs( + pca, + **kpms_dj_config, + interactive=False, + project_dir=kpms_project_output_dir, + ) + + tmpdir = tempfile.TemporaryDirectory() + + # plot variance summary + scree_path = pathlib.Path(tmpdir.name) / f"{fname}_scree_plot.png" + scree_fig.savefig(scree_path) + + # plot pcs + pcs_path = pathlib.Path(tmpdir.name) / f"{fname}_pcs_plot.png" + pcs_fig.savefig(pcs_path) + + # insert into table + self.insert1({**key, "scree_plot": scree_path, "pcs_plot": pcs_path}) + tmpdir.cleanup() + + +@schema +class PreFitReport(dj.Imported): + definition = """ + -> moseq_train.PreFit + --- + fitting_progress_pdf: attach # fitting_progress.pdf + """ + + def make(self, key): + prefit_model_name = (moseq_train.PreFit & key).fetch1("model_name") + prefit_model_dir = find_full_path( + moseq_train.get_kpms_processed_data_dir(), prefit_model_name + ) + prefit_output_dir = Path(prefit_model_dir) / "fitting_progress.pdf" + if os.path.exists(prefit_output_dir): + self.insert1({**key, "fitting_progress_pdf": 
prefit_output_dir}) + else: + raise FileNotFoundError( + f"PreFit fitting_progress.pdf not found at {prefit_output_dir}" + ) + + +@schema +class FullFitReport(dj.Imported): + definition = """ + -> moseq_train.FullFit + --- + fitting_progress_pdf: attach # fitting_progress.pdf + """ + + def make(self, key): + fullfit_model_name = (moseq_train.FullFit & key).fetch1("model_name") + fullfit_model_dir = find_full_path( + moseq_train.get_kpms_processed_data_dir(), fullfit_model_name + ) + fullfit_output_dir = Path(fullfit_model_dir) / "fitting_progress.pdf" + if os.path.exists(fullfit_output_dir): + self.insert1({**key, "fitting_progress_pdf": fullfit_output_dir}) + else: + raise FileNotFoundError( + f"FullFit fitting_progress.pdf not found at {fullfit_output_dir}" + ) + + +@schema +class InferenceReport(dj.Imported): + definition = """ + -> moseq_infer.Inference + --- + syllable_frequencies: attach + similarity_dendrogram_png: attach + similarity_dendrogram_pdf: attach + all_trajectories_gif: attach + all_trajectories_pdf: attach + """ + + class Trajectory(dj.Part): + definition = """ + -> master + syllable_id: int + --- + plot_gif: attach + plot_pdf: attach + grid_movie: attach + """ + + def make(self, key): + import imageio + + task_info = (moseq_infer.InferenceTask & key).fetch1() + model = (moseq_infer.Model & {"model_id": task_info["model_id"]}).fetch1() + + model_dir = find_full_path( + moseq_train.get_kpms_processed_data_dir(), model["model_dir"] + ) + output_dir = Path(model_dir) / task_info["inference_output_dir"] + + # Insert per-inference entry + self.insert1( + { + **key, + "syllable_frequencies": output_dir / "syllable_frequencies.png", + "similarity_dendrogram_png": output_dir / "similarity_dendrogram.png", + "similarity_dendrogram_pdf": output_dir / "similarity_dendrogram.pdf", + "all_trajectories_gif": output_dir + / "trajectory_plots" + / "all_trajectories.gif", + "all_trajectories_pdf": output_dir + / "trajectory_plots" + / "all_trajectories.pdf", + } 
+ ) + + # Insert per-syllable visuals + for syllable in (moseq_infer.Inference.GridMoviesSampledInstances & key).fetch( + "syllable" + ): + video_mp4_path = output_dir / "grid_movies" / f"syllable{syllable}.mp4" + video_mp4_to_gif_path = ( + output_dir / "grid_movies" / f"syllable{syllable}_grid_movie.gif" + ) + reader = imageio.get_reader(video_mp4_path) + fps = reader.get_meta_data()["fps"] + writer = imageio.get_writer(video_mp4_to_gif_path, fps=fps, loop=0) + + for frame in reader: + writer.append_data(frame) + + writer.close() + + self.Trajectory.insert1( + { + **key, + "syllable_id": syllable, + "plot_gif": output_dir + / "trajectory_plots" + / f"syllable{syllable}.gif", + "plot_pdf": output_dir + / "trajectory_plots" + / f"syllable{syllable}.pdf", + "grid_movie": video_mp4_to_gif_path, + } + ) From 6961f1e59bed7425392d726832b4bc0f603bb778 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 9 Sep 2025 13:49:34 +0100 Subject: [PATCH 28/61] review: apply suggestions to `moseq_train` --- element_moseq/moseq_train.py | 102 +++++++++++++++++++---------------- 1 file changed, 56 insertions(+), 46 deletions(-) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index 3d596620..6a1c8c60 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -18,9 +18,7 @@ from .readers import kpms_reader schema = dj.schema() - _linking_module = None - logger = dj.logger @@ -168,8 +166,6 @@ class VideoFile(dj.Part): video_id : int # Unique ID for each video corresponding to each keypoint data file, relative to root data directory --- video_path : varchar(1000) # Filepath of each video from which the keypoints are derived, relative to root data directory - group_label='' : varchar(100) # Assign a group label (such as “mutant” or “wildtype”) to each recording. Relevant for performing group-wise comparisons. 
- video_duration=0 : int # Duration of each video in minutes (if not provided, it will be automatically calculated in `PreProcessing`). """ @@ -226,7 +222,7 @@ class PCATask(dj.Manual): @schema -class PreProcessing(dj.Imported): +class PreProcessing(dj.Computed): """ Preprocess keypoint data by cleaning outliers and setting up the Keypoint-MoSeq project configuration. @@ -253,9 +249,17 @@ class PreProcessing(dj.Imported): confidences : longblob # Dictionary mapping filenames to `likelihood` scores as ndarrays of shape (n_frames, n_bodyparts) formatted_bodyparts : longblob # List of bodypart names. The order of the names matches the order of the bodyparts in `coordinates` and `confidences`. average_frame_rate : float # Average frame rate of the videos for model training (used for kappa calculation). - frame_rates : longblob # List of the frame rates of the videos for model training """ + class Video(dj.Part): + definition = """ + -> master + video_name: varchar(255) + --- + video_duration=0 : int # Duration of each video in minutes (if not provided, it will be automatically calculated in `PreProcessing`). + frame_rate=0 : float # Frame rate of the video. 
+ """ + def make_fetch(self, key): """ Make function to: @@ -376,14 +380,14 @@ def make_compute( ) if pose_estimation_method == "deeplabcut": - cfg = kpset_dir / "config.yaml" - if not cfg.exists(): - cfg = kpset_dir / "config.yml" - if not cfg.exists(): + from .readers.kpms_reader import _base_config_path + + cfg_path = _base_config_path(kpset_dir) + if not os.path.exists(cfg_path): raise FileNotFoundError( f"No DLC config.(yml|yaml) found in {kpset_dir}" ) - # base `config.yml` is created with task_mode='trigger' + cfg = Path(cfg_path) setup_project( project_dir=kpms_project_output_dir.as_posix(), deeplabcut_config=cfg.as_posix(), @@ -407,9 +411,9 @@ def make_compute( filepath_pattern=kpset_dir, format=pose_estimation_method ) - # compute FPS and video duration frame_rate_list = [] video_duration_list = [] + video_metadata_list = [] for fp, video_id in zip(video_paths, video_ids): video_path = (find_full_path(get_kpms_root_data_dir(), fp)).as_posix() cap = cv2.VideoCapture(video_path) @@ -423,6 +427,24 @@ def make_compute( frame_rate_list.append(fps) video_duration_list.append((video_id, int(duration_minutes))) + # Get video name for the Video part table + video_key = {"kpset_id": key["kpset_id"], "video_id": video_id} + if KeypointSet.VideoFile & video_key: + video_record = (KeypointSet.VideoFile & video_key).fetch1() + video_name = Path( + video_record["video_path"] + ).stem # Get filename without extension + video_metadata_list.append( + { + "video_id": video_id, + "video_name": video_name, + "video_duration": int(duration_minutes), + "frame_rate": fps, + } + ) + else: + logger.warning(f"Video record not found for video_id {video_id}") + average_frame_rate = float(np.mean(frame_rate_list)) # Generate a copy of config.yml with the generated/updated info after it is known @@ -434,13 +456,13 @@ def make_compute( posterior_bodyparts=list(posterior_bodyparts), outlier_scale_factor=float(outlier_scale_factor), ) - kpms_reader.dj_update_config( + 
kpms_reader.update_kpms_dj_config( kpms_project_output_dir, fps=average_frame_rate, ) # Remove outlier keypoints - kpms_config = kpms_reader.dj_load_config(kpms_project_output_dir) + kpms_config = kpms_reader.load_kpms_dj_config(kpms_project_output_dir) cleaned_coordinates = {} cleaned_confidences = {} @@ -486,6 +508,7 @@ def make_compute( average_frame_rate, frame_rate_list, video_duration_list, + video_metadata_list, ) def make_insert( @@ -497,6 +520,7 @@ def make_insert( average_frame_rate, frame_rate_list, video_duration_list, + video_metadata_list, ): # Insert the main preprocessing results self.insert1( @@ -506,33 +530,19 @@ def make_insert( confidences=cleaned_confidences, formatted_bodyparts=formatted_bodyparts, average_frame_rate=average_frame_rate, - frame_rates=frame_rate_list, ) ) - # Update video durations in KeypointSet.VideoFile table - for video_id, duration_minutes in video_duration_list: - try: - # Check if the video record exists - video_key = {"kpset_id": key["kpset_id"], "video_id": video_id} - if KeypointSet.VideoFile & video_key: - # For Manual tables, we need to get the existing record and update it - existing_record = (KeypointSet.VideoFile & video_key).fetch1() - - # Create updated record with new video_duration - updated_record = dict(existing_record) - updated_record["video_duration"] = duration_minutes - - # Delete the old record and insert the updated one - KeypointSet.VideoFile.update1(updated_record) - else: - logger.warning(f"Video record not found for video_id {video_id}") - - except Exception as e: - logger.warning( - f"Warning: Could not update video duration for video_id {video_id}: {e}" + # Insert video metadata into the Video part table + for video_metadata in video_metadata_list: + self.Video.insert1( + dict( + **key, + video_name=video_metadata["video_name"], + video_duration=video_metadata["video_duration"], + frame_rate=video_metadata["frame_rate"], ) - # Continue processing even if update fails + ) @schema @@ -581,7 
+591,7 @@ def make(self, key): Path(get_kpms_processed_data_dir()) / kpms_project_output_dir ) - kpms_default_config = kpms_reader.dj_load_config(kpms_project_output_dir) + kpms_default_config = kpms_reader.load_kpms_dj_config(kpms_project_output_dir) coordinates, confidences = (PreProcessing & key).fetch1( "coordinates", "confidences" ) @@ -789,14 +799,14 @@ def make(self, key): from keypoint_moseq import estimate_sigmasq_loc # Update the existing kpms_dj_config.yml with new latent_dim and kappa values - kpms_reader.dj_update_config( + kpms_reader.update_kpms_dj_config( kpms_project_output_dir, latent_dim=int(pre_latent_dim), kappa=float(pre_kappa), ) # Load the updated config for use in model fitting - kpms_dj_config = kpms_reader.dj_load_config(kpms_project_output_dir) + kpms_dj_config = kpms_reader.load_kpms_dj_config(kpms_project_output_dir) pca_path = kpms_project_output_dir / "pca.p" if pca_path.exists(): @@ -813,14 +823,14 @@ def make(self, key): coordinates=coordinates, confidences=confidences, **kpms_dj_config ) - kpms_reader.dj_update_config( + kpms_reader.update_kpms_dj_config( kpms_project_output_dir, sigmasq_loc=estimate_sigmasq_loc( data["Y"], data["mask"], filter_size=int(kpms_dj_config["fps"]) ), ) - kpms_dj_config = kpms_reader.dj_load_config( + kpms_dj_config = kpms_reader.load_kpms_dj_config( project_dir=kpms_project_output_dir ) @@ -965,13 +975,13 @@ def make(self, key): "model_name", ) if task_mode == "trigger": - kpms_reader.dj_update_config( + kpms_reader.update_kpms_dj_config( project_dir=kpms_project_output_dir, latent_dim=int(full_latent_dim), kappa=float(full_kappa), ) - kpms_dj_config = kpms_reader.dj_load_config( + kpms_dj_config = kpms_reader.load_kpms_dj_config( project_dir=kpms_project_output_dir ) @@ -989,14 +999,14 @@ def make(self, key): data, metadata = format_data( coordinates=coordinates, confidences=confidences, **kpms_dj_config ) - kpms_reader.dj_update_config( + kpms_reader.update_kpms_dj_config( 
project_dir=kpms_project_output_dir, sigmasq_loc=estimate_sigmasq_loc( data["Y"], data["mask"], filter_size=int(kpms_dj_config["fps"]) ), ) - kpms_dj_config = kpms_reader.dj_load_config( + kpms_dj_config = kpms_reader.load_kpms_dj_config( project_dir=kpms_project_output_dir ) From fcd29eebaa866c251e73adecbddc6baf57837194 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 9 Sep 2025 13:50:35 +0100 Subject: [PATCH 29/61] refactor(moseq_infer) --- element_moseq/moseq_infer.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/element_moseq/moseq_infer.py b/element_moseq/moseq_infer.py index abccb761..4a19d74e 100644 --- a/element_moseq/moseq_infer.py +++ b/element_moseq/moseq_infer.py @@ -5,7 +5,6 @@ import importlib import inspect -import os from datetime import datetime, timezone from pathlib import Path @@ -163,6 +162,7 @@ class Inference(dj.Computed): definition = """ -> InferenceTask # `InferenceTask` key --- + file_h5 : attach # File path of the results.h5 file inference_duration=NULL : float # Time duration (seconds) of the inference computation """ @@ -185,6 +185,7 @@ class MotionSequence(dj.Part): latent_state : longblob # Inferred low-dim pose state (x). Low-dimensional representation of the animal's pose in each frame. These are similar to PCA scores, are modified to reflect the pose dynamics and noise estimates inferred by the model centroid : longblob # Inferred centroid (v). The centroid of the animal in each frame, as estimated by the model heading : longblob # Inferred heading (h). 
The heading of the animal in each frame, as estimated by the model + file_csv : attach # File path of the results.csv file """ class GridMoviesSampledInstances(dj.Part): @@ -267,7 +268,7 @@ def make(self, key): ) keypointset_dir = find_full_path(kpms_root, keypointset_dir) - inference_output_dir = Path(os.path.join(model_dir, inference_output_dir)) + inference_output_dir = Path(model_dir) / inference_output_dir if not inference_output_dir.exists(): inference_output_dir.mkdir(parents=True, exist_ok=True) @@ -300,7 +301,7 @@ def make(self, key): support for another format method, please reach out to us at `support@datajoint.com`." ) - kpms_dj_config = kpms_reader.dj_load_config(model_dir.parent) + kpms_dj_config = kpms_reader.load_kpms_dj_config(model_dir.parent) if kpms_dj_config: data, metadata = format_data(coordinates, confidences, **kpms_dj_config) @@ -401,7 +402,13 @@ def make(self, key): duration_seconds = None - self.insert1({**key, "inference_duration": duration_seconds}) + self.insert1( + { + **key, + "inference_duration": duration_seconds, + "file_h5": (inference_output_dir / "results.h5").as_posix(), + } + ) for result_idx, result in results.items(): self.MotionSequence.insert1( @@ -412,6 +419,9 @@ def make(self, key): "latent_state": result["latent_state"], "centroid": result["centroid"], "heading": result["heading"], + "file_csv": ( + inference_output_dir / "results_as_csv" / f"{result_idx}.csv" + ).as_posix(), } ) From 5d8f036d7eaa7fac3eb3193c2e656998da6f87b3 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 9 Sep 2025 13:56:37 +0100 Subject: [PATCH 30/61] feat(plotting): move and refactor `viz_utils` --- element_moseq/plotting/__init__.py | 0 element_moseq/plotting/viz_utils.py | 334 ++++++++++++++++++++++++++++ 2 files changed, 334 insertions(+) create mode 100644 element_moseq/plotting/__init__.py create mode 100644 element_moseq/plotting/viz_utils.py diff --git a/element_moseq/plotting/__init__.py b/element_moseq/plotting/__init__.py new file 
mode 100644 index 00000000..e69de29b diff --git a/element_moseq/plotting/viz_utils.py b/element_moseq/plotting/viz_utils.py new file mode 100644 index 00000000..51468428 --- /dev/null +++ b/element_moseq/plotting/viz_utils.py @@ -0,0 +1,334 @@ +# ---- Modified version of the viz functions from the main branch of keypoint_moseq ---- + +import os +import re +from difflib import SequenceMatcher +from pathlib import Path +from textwrap import fill +from typing import Dict, List, Optional, Tuple + +import matplotlib.pyplot as plt +import numpy as np +from jax_moseq.models.keypoint_slds import center_embedding +from keypoint_moseq.util import ( + get_distance_to_medoid, + get_edges, + plot_keypoint_traces, + plot_pcs_3D, +) + +_DLC_SUFFIX_RE = re.compile( + r"(?:DLC_[A-Za-z0-9]+[A-Za-z]+(?:\d+)?(?:[A-Za-z]+)?" # scorer-ish token + r"(?:\w+)*)" # optional extra blobs + r"(?:shuffle\d+)?" # shuffleN + r"(?:_\d+)?$" # _iter +) + + +def _normalize_name(name: str) -> str: + """ + Normalize a recording/video string for matching: + - lowercase, strip whitespace + - drop extension if present + - remove common DLC suffix blob (e.g., '...DLC_resnet50_...shuffle1_500000') + - collapse separators to single spaces + """ + s = name.lower().strip() + s = Path(s).stem + s = _DLC_SUFFIX_RE.sub("", s) # strip DLC tail if present + s = re.sub(r"[\s._-]+", " ", s).strip() + return s + + +def build_recording_to_video_id( + recording_names: List[str], + video_paths: List[str], + video_ids: List[int], + fuzzy_threshold: float = 0.80, +) -> Dict[str, Optional[int]]: + """ + Returns: {recording_name -> video_id or None if no good match} + Strategy: exact normalized stem match; if none, substring; then fuzzy. 
+ """ + # candidate stems from videos + stems: List[Tuple[str, int]] = [ + (_normalize_name(Path(p).name), vid) for p, vid in zip(video_paths, video_ids) + ] + + mapping: Dict[str, Optional[int]] = {} + + for rec in recording_names: + nrec = _normalize_name(rec) + + # 1) exact normalized match + exact = [vid for stem, vid in stems if stem == nrec] + if exact: + mapping[rec] = exact[0] + continue + + # 2) substring either way (choose longest stem to disambiguate) + subs = [(stem, vid) for stem, vid in stems if nrec in stem or stem in nrec] + if subs: + subs.sort(key=lambda x: len(x[0]), reverse=True) + mapping[rec] = subs[0][1] + continue + + # 3) fuzzy best match + best_vid, best_ratio = None, 0.0 + for stem, vid in stems: + r = SequenceMatcher(None, nrec, stem).ratio() + if r > best_ratio: + best_ratio, best_vid = r, vid + mapping[rec] = best_vid if best_ratio >= fuzzy_threshold else None + + return mapping + + +def plot_medoid_distance_outliers( + project_dir: str, + recording_name: str, + original_coordinates: np.ndarray, + interpolated_coordinates: np.ndarray, + outlier_mask, + outlier_thresholds, + bodyparts: list[str], + **kwargs, +): + """Create and save a plot comparing distance-to-medoid for original vs. interpolated keypoints. + + Generates a multi-panel plot showing the distance from each keypoint to the medoid + position for both original and interpolated coordinates. The plot includes threshold + lines and shaded regions for outlier frames. Saves the figure to the QA plots + directory. + + Parameters + ------- + project_dir: str + Path to the project directory where the plot will be saved. + + recording_name: str + Name of the recording, used for the plot title and filename. + + original_coordinates: ndarray of shape (n_frames, n_keypoints, keypoint_dim) + Original keypoint coordinates before interpolation. + + interpolated_coordinates: ndarray of shape (n_frames, n_keypoints, keypoint_dim) + Keypoint coordinates after interpolation. 
+ + outlier_mask: ndarray of shape (n_frames, n_keypoints) + Boolean mask indicating outlier keypoints (True = outlier). + + outlier_thresholds: ndarray of shape (n_keypoints,) + Distance thresholds for each keypoint above which points are considered outliers. + + bodyparts: list of str + Names of bodyparts corresponding to each keypoint. Must have length equal to + n_keypoints. + + **kwargs + Additional keyword arguments (ignored), usually overflow from **config(). + + Returns + ------- + None + The plot is saved to 'QA/plots/keypoint_distance_outliers/{recording_name}.png'. + """ + + plot_path = os.path.join( + project_dir, + "quality_assurance", + "plots", + "keypoint_distance_outliers", + f"{recording_name}.png", + ) + os.makedirs(os.path.dirname(plot_path), exist_ok=True) + + original_distances = get_distance_to_medoid( + original_coordinates + ) # (n_frames, n_keypoints) + interpolated_distances = get_distance_to_medoid( + interpolated_coordinates + ) # (n_frames, n_keypoints) + + fig = plot_keypoint_traces( + traces=[original_distances, interpolated_distances], + plot_title=recording_name, + bodyparts=bodyparts, + line_labels=["Original", "Interpolated"], + thresholds=outlier_thresholds, + shading_mask=outlier_mask, + ) + + fig.savefig(plot_path, dpi=300) + + plt.close() + print(f"Saved keypoint distance outlier plot for {recording_name} to {plot_path}.") + return fig + + +def plot_pcs( + pca, + *, + use_bodyparts, + skeleton, + keypoint_colormap="autumn", + keypoint_colors=None, + savefig=True, + project_dir=None, + scale=1, + plot_n_pcs=10, + axis_size=(2, 1.5), + ncols=5, + node_size=30.0, + line_width=2.0, + interactive=True, + **kwargs, +): + """ + Visualize the components of a fitted PCA model. + + For each PC, a subplot shows the mean pose (semi-transparent) along with a + perturbation of the mean pose in the direction of the PC. 
+ + Parameters + ---------- + pca : :py:func:`sklearn.decomposition.PCA` + Fitted PCA model + + use_bodyparts : list of str + List of bodyparts to that are used in the model; used to index bodypart + names in the skeleton. + + skeleton : list + List of edges that define the skeleton, where each edge is a pair of + bodypart names. + + keypoint_colormap : str + Name of a matplotlib colormap to use for coloring the keypoints. + + keypoint_colors : array-like, shape=(num_keypoints,3), default=None + Color for each keypoint. If None, `keypoint_colormap` is used. If the + dtype is int, the values are assumed to be in the range 0-255, + otherwise they are assumed to be in the range 0-1. + + savefig : bool, True + Whether to save the figure to a file. If true, the figure is saved to + `{project_dir}/pcs-{xy/xz/yz}.pdf` (`xz` and `yz` are only included + for 3D data). + + project_dir : str, default=None + Path to the project directory. Required if `savefig` is True. + + scale : float, default=0.5 + Scale factor for the perturbation of the mean pose. + + plot_n_pcs : int, default=10 + Number of PCs to plot. + + axis_size : tuple of float, default=(2,1.5) + Size of each subplot in inches. + + ncols : int, default=5 + Number of columns in the figure. + + node_size : float, default=30.0 + Size of the keypoints in the figure. + + line_width: float, default=2.0 + Width of edges in skeleton + + interactive : bool, default=True + For 3D data, whether to generate an interactive 3D plot. 
+ """ + k = len(use_bodyparts) + d = len(pca.mean_) // (k - 1) + + if keypoint_colors is None: + cmap = plt.cm.get_cmap(keypoint_colormap) + keypoint_colors = cmap(np.linspace(0, 1, k)) + + Gamma = np.array(center_embedding(k)) + edges = get_edges(use_bodyparts, skeleton) + plot_n_pcs = min(plot_n_pcs, pca.components_.shape[0]) + + magnitude = np.sqrt((pca.mean_**2).mean()) * scale + ymean = Gamma @ pca.mean_.reshape(k - 1, d) + ypcs = (pca.mean_ + magnitude * pca.components_).reshape(-1, k - 1, d) + ypcs = Gamma[np.newaxis] @ ypcs[:plot_n_pcs] + + if d == 2: + dims_list, names = [[0, 1]], ["xy"] + if d == 3: + dims_list, names = [[0, 1], [0, 2]], ["xy", "xz"] + + for dims, name in zip(dims_list, names): + nrows = int(np.ceil(plot_n_pcs / ncols)) + fig, axs = plt.subplots(nrows, ncols, sharex=True, sharey=True) + for i, ax in enumerate(axs.flat): + if i >= plot_n_pcs: + ax.axis("off") + continue + + for e in edges: + ax.plot( + *ymean[:, dims][e].T, + color=keypoint_colors[e[0]], + zorder=0, + alpha=0.25, + linewidth=line_width, + ) + ax.plot( + *ypcs[i][:, dims][e].T, + color="k", + zorder=2, + linewidth=line_width + 0.2, + ) + ax.plot( + *ypcs[i][:, dims][e].T, + color=keypoint_colors[e[0]], + zorder=3, + linewidth=line_width, + ) + + ax.scatter( + *ymean[:, dims].T, + c=keypoint_colors, + s=node_size, + zorder=1, + alpha=0.25, + linewidth=0, + ) + ax.scatter( + *ypcs[i][:, dims].T, + c=keypoint_colors, + s=node_size, + zorder=4, + edgecolor="k", + linewidth=0.2, + ) + + ax.set_title(f"PC {i+1}", fontsize=10) + ax.set_aspect("equal") + ax.axis("off") + + fig.set_size_inches((axis_size[0] * ncols, axis_size[1] * nrows)) + plt.tight_layout() + + if savefig: + assert project_dir is not None, fill( + "The `savefig` option requires a `project_dir`" + ) + plt.savefig(os.path.join(project_dir, f"pcs-{name}.pdf")) + plt.show() + + if interactive and d == 3: + plot_pcs_3D( + ymean, + ypcs, + edges, + keypoint_colormap, + project_dir if savefig else None, + node_size / 3, 
+ line_width * 2, + ) + return fig From 0259b3abfb6ab279c2fbeb2b9b8fa3cd4a47fc57 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 9 Sep 2025 13:57:02 +0100 Subject: [PATCH 31/61] refactor(kpms_reader) --- element_moseq/readers/kpms_reader.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/element_moseq/readers/kpms_reader.py b/element_moseq/readers/kpms_reader.py index 7f3a5367..14ec2e49 100644 --- a/element_moseq/readers/kpms_reader.py +++ b/element_moseq/readers/kpms_reader.py @@ -18,10 +18,21 @@ def _dj_config_path(project_dir: Union[str, os.PathLike]) -> str: def _base_config_path(project_dir: Union[str, os.PathLike]) -> str: - return str(Path(project_dir) / BASE_CONFIG) - - -def _check_config_validity_like_upstream(config: Dict[str, Any]) -> bool: + """Return the path to the base config file, checking for both .yml and .yaml extensions.""" + project_path = Path(project_dir) + # Check for config.yml first (preferred) + config_yml = project_path / "config.yml" + if config_yml.exists(): + return str(config_yml) + # Fall back to config.yaml + config_yaml = project_path / "config.yaml" + if config_yaml.exists(): + return str(config_yaml) + # If neither exists, return the default (config.yml) + return str(config_yml) + + +def _check_config_validity(config: Dict[str, Any]) -> bool: """ Minimal mirror of keypoint_moseq.io.check_config_validity logic that matters for anatomy consistency (anterior/posterior must be subset of use_bodyparts). @@ -67,7 +78,8 @@ def dj_generate_config(project_dir: str, **kwargs) -> str: else: if not os.path.exists(base_cfg_path): raise FileNotFoundError( - f"Missing base config at {base_cfg_path}. Run upstream setup_project first." + f"Missing base config at {base_cfg_path}. Run upstream setup_project first. " + f"Expected either config.yml or config.yaml in {project_dir}." 
) with open(base_cfg_path, "r") as f: cfg = yaml.safe_load(f) or {} @@ -85,7 +97,7 @@ def dj_generate_config(project_dir: str, **kwargs) -> str: return dj_cfg_path -def dj_load_config( +def load_kpms_dj_config( project_dir: str, check_if_valid: bool = True, build_indexes: bool = True ) -> Dict[str, Any]: """ @@ -106,7 +118,7 @@ def dj_load_config( cfg = yaml.safe_load(f) or {} if check_if_valid: - _check_config_validity_like_upstream( + _check_config_validity( cfg ) # readthedocs source mirrors this logic. :contentReference[oaicite:0]{index=0} @@ -125,7 +137,7 @@ def dj_load_config( return cfg -def dj_update_config(project_dir: str, **kwargs) -> Dict[str, Any]: +def update_kpms_dj_config(project_dir: str, **kwargs) -> Dict[str, Any]: """ Update `kpms_dj_config.yml` with provided top-level kwargs (same pattern as keypoint_moseq.io.update_config), then rewrite the file and return the dict. From 9e8bb3ec255c818761bfc99543bf76e967471ace Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 9 Sep 2025 13:57:24 +0100 Subject: [PATCH 32/61] update(tutorial_pipeline) --- notebooks/tutorial_pipeline.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/notebooks/tutorial_pipeline.py b/notebooks/tutorial_pipeline.py index 60c7171f..ffae002d 100644 --- a/notebooks/tutorial_pipeline.py +++ b/notebooks/tutorial_pipeline.py @@ -4,7 +4,7 @@ from element_lab import lab from element_animal import subject from element_session import session_with_datetime as session -from element_moseq import moseq_train, moseq_infer +from element_moseq import moseq_train, moseq_infer, report from element_animal.subject import Subject from element_lab.lab import Source, Lab, Protocol, User, Project @@ -80,3 +80,4 @@ class Device(dj.Lookup): moseq_train.activate(db_prefix + "moseq_train", linking_module=__name__) moseq_infer.activate(db_prefix + "moseq_infer", linking_module=__name__) +report.activate(db_prefix + "report", linking_module=__name__) From 
25b427f43952f14dfdb69f2dabf74a6dab85efc0 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 9 Sep 2025 14:06:06 +0100 Subject: [PATCH 33/61] minor fix in viz_utils --- element_moseq/plotting/viz_utils.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/element_moseq/plotting/viz_utils.py b/element_moseq/plotting/viz_utils.py index 51468428..14e35882 100644 --- a/element_moseq/plotting/viz_utils.py +++ b/element_moseq/plotting/viz_utils.py @@ -10,12 +10,8 @@ import matplotlib.pyplot as plt import numpy as np from jax_moseq.models.keypoint_slds import center_embedding -from keypoint_moseq.util import ( - get_distance_to_medoid, - get_edges, - plot_keypoint_traces, - plot_pcs_3D, -) +from keypoint_moseq.util import get_distance_to_medoid, get_edges, plot_keypoint_traces +from keypoint_moseq.viz import plot_pcs_3D _DLC_SUFFIX_RE = re.compile( r"(?:DLC_[A-Za-z0-9]+[A-Za-z]+(?:\d+)?(?:[A-Za-z]+)?" # scorer-ish token From 0a46a63f03baec066e97b41c42b20ae03e2c4258 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 9 Sep 2025 14:06:26 +0100 Subject: [PATCH 34/61] update images --- images/pipeline.svg | 502 ++++++++++++--------- images/pipeline_moseq_infer.svg | 112 ++--- images/pipeline_moseq_infer_and_report.svg | 162 +++++++ images/pipeline_moseq_train.svg | 254 ++++++----- images/pipeline_moseq_train_and_report.svg | 275 +++++++++++ 5 files changed, 927 insertions(+), 378 deletions(-) create mode 100644 images/pipeline_moseq_infer_and_report.svg create mode 100644 images/pipeline_moseq_train_and_report.svg diff --git a/images/pipeline.svg b/images/pipeline.svg index 0e1b1929..61ade29c 100644 --- a/images/pipeline.svg +++ b/images/pipeline.svg @@ -1,327 +1,425 @@ - - - + + + moseq_train.KeypointSet - -moseq_train.KeypointSet + +moseq_train.KeypointSet - + moseq_train.KeypointSet.VideoFile - - -moseq_train.KeypointSet.VideoFile + + +moseq_train.KeypointSet.VideoFile moseq_train.KeypointSet->moseq_train.KeypointSet.VideoFile - + - + 
moseq_train.Bodyparts - - -moseq_train.Bodyparts + + +moseq_train.Bodyparts moseq_train.KeypointSet->moseq_train.Bodyparts - + - + -moseq_train.FullFitTask - - -moseq_train.FullFitTask +moseq_train.PCATask + + +moseq_train.PCATask - - -moseq_train.FullFit - - -moseq_train.FullFit + + +moseq_train.PreProcessing + + +moseq_train.PreProcessing - + -moseq_train.FullFitTask->moseq_train.FullFit - +moseq_train.PCATask->moseq_train.PreProcessing + - + -moseq_train.PreFit - - -moseq_train.PreFit +moseq_train.SelectedFullFit + + +moseq_train.SelectedFullFit - - -moseq_train.PCAFit - - -moseq_train.PCAFit + + +moseq_infer.Model + + +moseq_infer.Model - + -moseq_train.PCAFit->moseq_train.FullFitTask - +moseq_train.SelectedFullFit->moseq_infer.Model + - - -moseq_train.LatentDimension - - -moseq_train.LatentDimension + + +report.PCAReport + + +report.PCAReport - - -moseq_train.PCAFit->moseq_train.LatentDimension - - - - -moseq_train.PreFitTask - - -moseq_train.PreFitTask + + +moseq_infer.VideoRecording.File + + +moseq_infer.VideoRecording.File - - -moseq_train.PCAFit->moseq_train.PreFitTask - + + +moseq_train.FullFit + + +moseq_train.FullFit + - - -moseq_infer.Model - - -moseq_infer.Model + + + +moseq_train.FullFit->moseq_train.SelectedFullFit + + + + +report.FullFitReport + + +report.FullFitReport - + + +moseq_train.FullFit->report.FullFitReport + + + -moseq_infer.InferenceTask - - -moseq_infer.InferenceTask +moseq_train.FullFitTask + + +moseq_train.FullFitTask - + -moseq_infer.Model->moseq_infer.InferenceTask - +moseq_train.FullFitTask->moseq_train.FullFit + - - -moseq_infer.Inference.GridMoviesSampledInstances - - -moseq_infer.Inference.GridMoviesSampledInstances + + +moseq_infer.Inference + + +moseq_infer.Inference - - -moseq_infer.Inference - - -moseq_infer.Inference + + +moseq_infer.Inference.MotionSequence + + +moseq_infer.Inference.MotionSequence - + -moseq_infer.InferenceTask->moseq_infer.Inference - +moseq_infer.Inference->moseq_infer.Inference.MotionSequence + - - 
-moseq_train.SelectedFullFit - - -moseq_train.SelectedFullFit + + +report.InferenceReport + + +report.InferenceReport - + -moseq_train.SelectedFullFit->moseq_infer.Model - +moseq_infer.Inference->report.InferenceReport + - - -moseq_train.PoseEstimationMethod - - -moseq_train.PoseEstimationMethod + + +moseq_infer.Inference.GridMoviesSampledInstances + + +moseq_infer.Inference.GridMoviesSampledInstances - + -moseq_train.PoseEstimationMethod->moseq_train.KeypointSet - +moseq_infer.Inference->moseq_infer.Inference.GridMoviesSampledInstances + + + + +moseq_train.PreFitTask + + +moseq_train.PreFitTask + + + + + +moseq_train.PreFit + + +moseq_train.PreFit + - - -moseq_train.PoseEstimationMethod->moseq_infer.InferenceTask - - + moseq_train.PreFitTask->moseq_train.PreFit - + - - -Device - - -Device + + +moseq_infer.InferenceTask + + +moseq_infer.InferenceTask + + +moseq_infer.Model->moseq_infer.InferenceTask + + - + moseq_infer.VideoRecording - - -moseq_infer.VideoRecording + + +moseq_infer.VideoRecording - + -Device->moseq_infer.VideoRecording - - - - -moseq_infer.VideoRecording.File - - -moseq_infer.VideoRecording.File - - - - - -moseq_train.PCATask - - -moseq_train.PCATask - +moseq_infer.VideoRecording->moseq_infer.VideoRecording.File + + + +moseq_infer.VideoRecording->moseq_infer.InferenceTask + - + moseq_train.Bodyparts->moseq_train.PCATask - + - - -moseq_train.PCAPrep - - -moseq_train.PCAPrep + + +session.Session + + +session.Session - - -moseq_train.PCAPrep->moseq_train.PCAFit - - - + -moseq_infer.Inference->moseq_infer.Inference.GridMoviesSampledInstances - +session.Session->moseq_infer.VideoRecording + - - -moseq_infer.Inference.MotionSequence - - -moseq_infer.Inference.MotionSequence + + +moseq_train.LatentDimension + + +moseq_train.LatentDimension - + -moseq_infer.Inference->moseq_infer.Inference.MotionSequence - +moseq_train.LatentDimension->report.PCAReport + - + + +report.InferenceReport.Trajectory + + +report.InferenceReport.Trajectory + + + + 
-moseq_infer.VideoRecording->moseq_infer.InferenceTask - +report.InferenceReport->report.InferenceReport.Trajectory + - + + +moseq_train.PCAFit + + +moseq_train.PCAFit + + + + -moseq_infer.VideoRecording->moseq_infer.VideoRecording.File - +moseq_train.PCAFit->moseq_train.FullFitTask + - + -moseq_train.FullFit->moseq_train.SelectedFullFit - +moseq_train.PCAFit->moseq_train.PreFitTask + + + + +moseq_train.PCAFit->moseq_train.LatentDimension + - + subject.Subject - - -subject.Subject + + +subject.Subject - - -session.Session - - -session.Session + + +subject.Subject->session.Session + + + + +moseq_train.PreProcessing.Video + + +moseq_train.PreProcessing.Video - - -subject.Subject->session.Session - + + +report.PreFitReport + + +report.PreFitReport + - - -moseq_train.PCATask->moseq_train.PCAPrep - - + -session.Session->moseq_infer.VideoRecording - +moseq_infer.InferenceTask->moseq_infer.Inference + + + + +Device + + +Device + + + + + +Device->moseq_infer.VideoRecording + + + + +moseq_train.PoseEstimationMethod + + +moseq_train.PoseEstimationMethod + + + + + +moseq_train.PoseEstimationMethod->moseq_train.KeypointSet + + + + +moseq_train.PoseEstimationMethod->moseq_infer.InferenceTask + + + + +report.PreProcessingReport + + +report.PreProcessingReport + + + + + +moseq_train.PreProcessing->moseq_train.PCAFit + + + + +moseq_train.PreProcessing->moseq_train.PreProcessing.Video + + + + +moseq_train.PreProcessing->report.PreProcessingReport + + + + +moseq_train.PreFit->report.PreFitReport + \ No newline at end of file diff --git a/images/pipeline_moseq_infer.svg b/images/pipeline_moseq_infer.svg index cd89e7a3..bc3d7103 100644 --- a/images/pipeline_moseq_infer.svg +++ b/images/pipeline_moseq_infer.svg @@ -1,98 +1,98 @@ - - - - - -moseq_infer.VideoRecording.File - - -moseq_infer.VideoRecording.File - - - - - -moseq_infer.Model - - -moseq_infer.Model - - - + + + - + moseq_infer.InferenceTask - - -moseq_infer.InferenceTask + + +moseq_infer.InferenceTask - - 
-moseq_infer.Model->moseq_infer.InferenceTask - - moseq_infer.Inference - - -moseq_infer.Inference + + +moseq_infer.Inference + + +moseq_infer.InferenceTask->moseq_infer.Inference + + - + moseq_infer.Inference.GridMoviesSampledInstances - - -moseq_infer.Inference.GridMoviesSampledInstances + + +moseq_infer.Inference.GridMoviesSampledInstances moseq_infer.Inference->moseq_infer.Inference.GridMoviesSampledInstances - + - + moseq_infer.Inference.MotionSequence - - -moseq_infer.Inference.MotionSequence + + +moseq_infer.Inference.MotionSequence moseq_infer.Inference->moseq_infer.Inference.MotionSequence - + - - -moseq_infer.VideoRecording - - -moseq_infer.VideoRecording + + +moseq_infer.VideoRecording.File + + +moseq_infer.VideoRecording.File - + + +moseq_infer.Model + + +moseq_infer.Model + + + + -moseq_infer.VideoRecording->moseq_infer.VideoRecording.File - +moseq_infer.Model->moseq_infer.InferenceTask + + + + +moseq_infer.VideoRecording + + +moseq_infer.VideoRecording + + moseq_infer.VideoRecording->moseq_infer.InferenceTask - + - + -moseq_infer.InferenceTask->moseq_infer.Inference - +moseq_infer.VideoRecording->moseq_infer.VideoRecording.File + \ No newline at end of file diff --git a/images/pipeline_moseq_infer_and_report.svg b/images/pipeline_moseq_infer_and_report.svg new file mode 100644 index 00000000..f4b92371 --- /dev/null +++ b/images/pipeline_moseq_infer_and_report.svg @@ -0,0 +1,162 @@ + + + + + +moseq_infer.InferenceTask + + +moseq_infer.InferenceTask + + + + + +moseq_infer.Inference + + +moseq_infer.Inference + + + + + +moseq_infer.InferenceTask->moseq_infer.Inference + + + + +moseq_infer.Inference.GridMoviesSampledInstances + + +moseq_infer.Inference.GridMoviesSampledInstances + + + + + +report.InferenceReport.Trajectory + + +report.InferenceReport.Trajectory + + + + + +report.InferenceReport + + +report.InferenceReport + + + + + +report.InferenceReport->report.InferenceReport.Trajectory + + + + 
+moseq_infer.Inference->moseq_infer.Inference.GridMoviesSampledInstances + + + + +moseq_infer.Inference->report.InferenceReport + + + + +moseq_infer.Inference.MotionSequence + + +moseq_infer.Inference.MotionSequence + + + + + +moseq_infer.Inference->moseq_infer.Inference.MotionSequence + + + + +report.PreProcessingReport + + +report.PreProcessingReport + + + + + +report.PCAReport + + +report.PCAReport + + + + + +moseq_infer.VideoRecording.File + + +moseq_infer.VideoRecording.File + + + + + +report.PreFitReport + + +report.PreFitReport + + + + + +report.FullFitReport + + +report.FullFitReport + + + + + +moseq_infer.Model + + +moseq_infer.Model + + + + + +moseq_infer.Model->moseq_infer.InferenceTask + + + + +moseq_infer.VideoRecording + + +moseq_infer.VideoRecording + + + + + +moseq_infer.VideoRecording->moseq_infer.InferenceTask + + + + +moseq_infer.VideoRecording->moseq_infer.VideoRecording.File + + + + \ No newline at end of file diff --git a/images/pipeline_moseq_train.svg b/images/pipeline_moseq_train.svg index 7be6103d..14e9a4e2 100644 --- a/images/pipeline_moseq_train.svg +++ b/images/pipeline_moseq_train.svg @@ -1,182 +1,196 @@ - - - - + + + + -moseq_train.KeypointSet.VideoFile - - -moseq_train.KeypointSet.VideoFile +moseq_train.Bodyparts + + +moseq_train.Bodyparts - - -moseq_train.PoseEstimationMethod - - -moseq_train.PoseEstimationMethod + + +moseq_train.PCATask + + +moseq_train.PCATask + + +moseq_train.Bodyparts->moseq_train.PCATask + + - + moseq_train.KeypointSet - - -moseq_train.KeypointSet + + +moseq_train.KeypointSet - - -moseq_train.PoseEstimationMethod->moseq_train.KeypointSet - + + +moseq_train.KeypointSet->moseq_train.Bodyparts + - - -moseq_train.LatentDimension - - -moseq_train.LatentDimension + + +moseq_train.KeypointSet.VideoFile + + +moseq_train.KeypointSet.VideoFile - - -moseq_train.Bodyparts - - -moseq_train.Bodyparts + + +moseq_train.KeypointSet->moseq_train.KeypointSet.VideoFile + + + + +moseq_train.FullFit + + +moseq_train.FullFit - + 
-moseq_train.PCATask - - -moseq_train.PCATask +moseq_train.SelectedFullFit + + +moseq_train.SelectedFullFit - - -moseq_train.Bodyparts->moseq_train.PCATask - - - - -moseq_train.PCAPrep - - -moseq_train.PCAPrep - - + + +moseq_train.FullFit->moseq_train.SelectedFullFit + - - -moseq_train.PCAFit - - -moseq_train.PCAFit + + +moseq_train.FullFitTask + + +moseq_train.FullFitTask - - -moseq_train.PCAPrep->moseq_train.PCAFit - - - - -moseq_train.KeypointSet->moseq_train.KeypointSet.VideoFile - - - + -moseq_train.KeypointSet->moseq_train.Bodyparts - +moseq_train.FullFitTask->moseq_train.FullFit + - - -moseq_train.FullFitTask - - -moseq_train.FullFitTask + + +moseq_train.LatentDimension + + +moseq_train.LatentDimension - - -moseq_train.FullFit - - -moseq_train.FullFit + + +moseq_train.PreProcessing + + +moseq_train.PreProcessing - + -moseq_train.FullFitTask->moseq_train.FullFit - +moseq_train.PCATask->moseq_train.PreProcessing + - - -moseq_train.SelectedFullFit - - -moseq_train.SelectedFullFit + + +moseq_train.PCAFit + + +moseq_train.PCAFit - + -moseq_train.FullFit->moseq_train.SelectedFullFit - +moseq_train.PCAFit->moseq_train.FullFitTask + + + + +moseq_train.PCAFit->moseq_train.LatentDimension + moseq_train.PreFitTask - - -moseq_train.PreFitTask + + +moseq_train.PreFitTask + + + + + +moseq_train.PCAFit->moseq_train.PreFitTask + + + + +moseq_train.PoseEstimationMethod + + +moseq_train.PoseEstimationMethod + + +moseq_train.PoseEstimationMethod->moseq_train.KeypointSet + + - + moseq_train.PreFit - - -moseq_train.PreFit + + +moseq_train.PreFit - + moseq_train.PreFitTask->moseq_train.PreFit - - - - -moseq_train.PCATask->moseq_train.PCAPrep - + - - -moseq_train.PCAFit->moseq_train.LatentDimension - + + +moseq_train.PreProcessing.Video + + +moseq_train.PreProcessing.Video + - - -moseq_train.PCAFit->moseq_train.FullFitTask - - + -moseq_train.PCAFit->moseq_train.PreFitTask - +moseq_train.PreProcessing->moseq_train.PCAFit + + + + 
+moseq_train.PreProcessing->moseq_train.PreProcessing.Video + \ No newline at end of file diff --git a/images/pipeline_moseq_train_and_report.svg b/images/pipeline_moseq_train_and_report.svg new file mode 100644 index 00000000..b79ed45d --- /dev/null +++ b/images/pipeline_moseq_train_and_report.svg @@ -0,0 +1,275 @@ + + + + + +moseq_train.KeypointSet + + +moseq_train.KeypointSet + + + + + +moseq_train.KeypointSet.VideoFile + + +moseq_train.KeypointSet.VideoFile + + + + + +moseq_train.KeypointSet->moseq_train.KeypointSet.VideoFile + + + + +moseq_train.Bodyparts + + +moseq_train.Bodyparts + + + + + +moseq_train.KeypointSet->moseq_train.Bodyparts + + + + +moseq_train.PCATask + + +moseq_train.PCATask + + + + + +moseq_train.PreProcessing + + +moseq_train.PreProcessing + + + + + +moseq_train.PCATask->moseq_train.PreProcessing + + + + +moseq_train.SelectedFullFit + + +moseq_train.SelectedFullFit + + + + + +report.PCAReport + + +report.PCAReport + + + + + +moseq_train.FullFit + + +moseq_train.FullFit + + + + + +moseq_train.FullFit->moseq_train.SelectedFullFit + + + + +report.FullFitReport + + +report.FullFitReport + + + + + +moseq_train.FullFit->report.FullFitReport + + + + +moseq_train.FullFitTask + + +moseq_train.FullFitTask + + + + + +moseq_train.FullFitTask->moseq_train.FullFit + + + + +moseq_train.PreFitTask + + +moseq_train.PreFitTask + + + + + +moseq_train.PreFit + + +moseq_train.PreFit + + + + + +moseq_train.PreFitTask->moseq_train.PreFit + + + + +moseq_train.Bodyparts->moseq_train.PCATask + + + + +moseq_train.LatentDimension + + +moseq_train.LatentDimension + + + + + +moseq_train.LatentDimension->report.PCAReport + + + + +report.InferenceReport.Trajectory + + +report.InferenceReport.Trajectory + + + + + +report.InferenceReport + + +report.InferenceReport + + + + + +report.InferenceReport->report.InferenceReport.Trajectory + + + + +moseq_train.PCAFit + + +moseq_train.PCAFit + + + + + +moseq_train.PCAFit->moseq_train.FullFitTask + + + + 
+moseq_train.PCAFit->moseq_train.PreFitTask + + + + +moseq_train.PCAFit->moseq_train.LatentDimension + + + + +moseq_train.PreProcessing.Video + + +moseq_train.PreProcessing.Video + + + + + +report.PreFitReport + + +report.PreFitReport + + + + + +moseq_train.PoseEstimationMethod + + +moseq_train.PoseEstimationMethod + + + + + +moseq_train.PoseEstimationMethod->moseq_train.KeypointSet + + + + +report.PreProcessingReport + + +report.PreProcessingReport + + + + + +moseq_train.PreProcessing->moseq_train.PCAFit + + + + +moseq_train.PreProcessing->moseq_train.PreProcessing.Video + + + + +moseq_train.PreProcessing->report.PreProcessingReport + + + + +moseq_train.PreFit->report.PreFitReport + + + + \ No newline at end of file From cca9f4d7f017a17a3af1783b7236d64c1b90729f Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 9 Sep 2025 14:29:12 +0100 Subject: [PATCH 35/61] update CHANGELOG --- CHANGELOG.md | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1847309f..fe7e9668 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,23 +3,25 @@ Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention. -## [0.4.0] - 2025-09-04 +## [0.4.0] - 2025-09-09 + +> **BREAKING CHANGES** - This version contains breaking changes due to keypoint-moseq upgrade and API refactoring. Please review the changes below and update your code accordingly. + ++ Feat - **BREAKING**: Upgrade keypoint-moseq from pinned 0.4.8 version to the latest version with breaking changes adding new features that are not compatible with the previous kpms versions. ++ Fix - **BREAKING**: Rename `kpms_reader` functions and add support for both `config.yml` and `config.yaml` file extensions ++ Fix - **BREAKING**: Correct generation of `kpms_dj_config.yml` and refactor `moseq_train` and `moseq_infer` to use the renamed functions. 
++ Fix - **BREAKING**: Rename `PCAPrep` to `PreProcessing` + Feat - Add new attribute `outlier_scale_factor` in `PCATask` table -+ Feat - Add feature to remove outlier keypoints in `PreProcessing` table -+ Feat - Refactor `PreProcessing` table to use 3-part make function -+ Feat - Add new attribute `registered_model_name` and `registered_model_desc` in `SelectedFullFit` table -+ Feat - Add new attribute `group_label` in `VideoFile` table for downstream statistical analysis -+ Feat - Add new attributes `video_duration` in `VideoFile` table and `average_frame_rate` in `PreProcessing` table -+ Feat - Add new computation for FPS and video duration in `PreProcessing` to calculate `video_duration` and `average_frame_rate` -+ Feat - Add `sigmasq_loc` feature in `PreFit` and `FullFit` to automatically estimate sigmasq_loc (prior controlling the centroid movement across frames) -+ Fix - Fix generation of `kpms_dj_config.yml` in `kpms_reader` to use `dj_load_config` and `dj_update_config` functions -+ Fix - `moseq_train` and `moseq_infer` to use `dj_load_config` and `dj_update_config` functions -+ Fix - Rename `PCAPrep` to `PreProcessing` ++ Feat - **BREAKING**: Add feature to remove outlier keypoints in `PreProcessing` table ++ Feat - Refactor `PreProcessing` table to use 3-part make function and add a new `Video` part table ++ Feat - Add new attributes `video_duration`, `frame_rate` and `average_frame_rate`in `PreProcessing` table and add new `Video` table to store these new computations ++ Feat - **BREAKING**: Add `sigmasq_loc` feature in `PreFit` and `FullFit` to automatically estimate sigmasq_loc (prior controlling the centroid movement across frames) + Fix - Remove JAX dependencies from `pyproject.toml` -+ Fix - Update dockerfile to use Python 3.11 -+ Fix - update folder name to primary attributes instead of datetime in `PreFit` and `FullFit` ++ Fix - Update dockerfile to use Python 3.11 and upgrade dependencies ++ Fix - Update folder name to a string of combined 
primary attributes instead of datetime in `PreFit` and `FullFit` + Fix - Refactor `moseq_infer` paths + Add - Update docstrings ++ Add - Update Images + Add - Update pre-commit file ## [0.3.2] - 2025-08-25 From 422c2bddc0c542b0e0f58ba3a5a7ba3d1508d127 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 9 Sep 2025 18:20:59 +0100 Subject: [PATCH 36/61] updated `model_name` varchar from 100 to 1000 --- element_moseq/moseq_infer.py | 2 +- element_moseq/moseq_train.py | 12 ++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/element_moseq/moseq_infer.py b/element_moseq/moseq_infer.py index 4a19d74e..afb14251 100644 --- a/element_moseq/moseq_infer.py +++ b/element_moseq/moseq_infer.py @@ -83,7 +83,7 @@ class Model(dj.Manual): definition = """ model_id : int # Unique ID for each model --- - model_name : varchar(64) # User-friendly model name + model_name : varchar(1000) # User-friendly model name model_dir : varchar(1000)# Model directory relative to root data directory model_desc='' : varchar(1000)# Optional. User-defined description of the model -> [nullable] moseq_train.SelectedFullFit diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index 6a1c8c60..5e2d731a 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -722,7 +722,7 @@ class PreFitTask(dj.Manual): pre_kappa : int # Kappa value to use for the model pre-fitting (controls syllable duration). pre_num_iterations : int # Number of Gibbs sampling iterations to run in the model pre-fitting (typically 10-50). 
--- - model_name='' : varchar(100) # Name of the model to be loaded if `task_mode='load'` + model_name='' : varchar(1000) # Name of the model to be loaded if `task_mode='load'` task_mode='load' :enum('load','trigger')# 'load': load computed analysis results, 'trigger': trigger computation pre_fit_desc='' : varchar(1000) # User-defined description of the pre-fitting task """ @@ -740,14 +740,14 @@ class PreFit(dj.Computed): Attributes: PreFitTask (foreign key) : `PreFitTask` Key. - model_name (varchar) : Name of the model as "kpms_project_output_dir/model_name". + model_name (varchar) : Name of the model as "model_name". pre_fit_duration (float) : Time duration (seconds) of the model fitting computation. """ definition = """ -> PreFitTask # `PreFitTask` Key --- - model_name='' : varchar(100) # Name of the model as "kpms_project_output_dir/model_name" + model_name='' : varchar(1000) # Name of the model as "kpms_project_output_dir/model_name" pre_fit_duration=NULL : float # Time duration (seconds) of the model fitting computation """ @@ -898,7 +898,7 @@ class FullFitTask(dj.Manual): full_kappa : int # Kappa value to use for the model full fitting (typically lower than pre-fit kappa). full_num_iterations : int # Number of Gibbs sampling iterations to run in the model full fitting (typically 200-500). 
--- - model_name='' : varchar(100) # Name of the model to be loaded if `task_mode='load'` + model_name='' : varchar(1000) # Name of the model to be loaded if `task_mode='load'` task_mode='load' :enum('load','trigger')# Trigger or load the task full_fit_desc='' : varchar(1000) # User-defined description of the model full fitting task """ @@ -923,7 +923,7 @@ class FullFit(dj.Computed): definition = """ -> FullFitTask # `FullFitTask` Key --- - model_name='' : varchar(100) # Name of the model as "kpms_project_output_dir/model_name" + model_name='' : varchar(1000) # Name of the model as "kpms_project_output_dir/model_name" full_fit_duration=NULL : float # Time duration (seconds) of the full fitting computation """ @@ -1064,6 +1064,6 @@ class SelectedFullFit(dj.Manual): definition = """ -> FullFit --- - registered_model_name : varchar(64) # User-friendly model name + registered_model_name : varchar(1000) # User-friendly model name registered_model_desc='' : varchar(1000) # Optional user-defined description """ From 5b5d30586f2937e5845d67145db178d335fcfee8 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 9 Sep 2025 18:37:29 +0100 Subject: [PATCH 37/61] update(report): new funciton name for `load_kpms_dj_config` --- element_moseq/report.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/element_moseq/report.py b/element_moseq/report.py index b710abe9..9d20a307 100644 --- a/element_moseq/report.py +++ b/element_moseq/report.py @@ -152,7 +152,9 @@ def make(self, key): kpms_project_output_dir = ( moseq_train.get_kpms_processed_data_dir() / kpms_project_output_dir ) - kpms_dj_config = kpms_reader.dj_load_config(project_dir=kpms_project_output_dir) + kpms_dj_config = kpms_reader.load_kpms_dj_config( + project_dir=kpms_project_output_dir + ) pca = load_pca(kpms_project_output_dir.as_posix()) From cfe7de73aae7a3b4a2f075e11b18e981f87ddf64 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Wed, 10 Sep 2025 14:29:43 +0100 Subject: [PATCH 38/61] 
refactor(PreProcessing): remove redundancy of variables --- element_moseq/moseq_train.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index 5e2d731a..d8f44eb7 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -411,9 +411,8 @@ def make_compute( filepath_pattern=kpset_dir, format=pose_estimation_method ) - frame_rate_list = [] - video_duration_list = [] video_metadata_list = [] + frame_rates = [] for fp, video_id in zip(video_paths, video_ids): video_path = (find_full_path(get_kpms_root_data_dir(), fp)).as_posix() cap = cv2.VideoCapture(video_path) @@ -424,8 +423,7 @@ def make_compute( # Calculate duration in minutes duration_minutes = (frame_count / fps) / 60.0 - frame_rate_list.append(fps) - video_duration_list.append((video_id, int(duration_minutes))) + frame_rates.append(fps) # Get video name for the Video part table video_key = {"kpset_id": key["kpset_id"], "video_id": video_id} @@ -445,7 +443,7 @@ def make_compute( else: logger.warning(f"Video record not found for video_id {video_id}") - average_frame_rate = float(np.mean(frame_rate_list)) + average_frame_rate = float(np.mean(frame_rates)) # Generate a copy of config.yml with the generated/updated info after it is known kpms_reader.dj_generate_config( @@ -499,15 +497,15 @@ def make_compute( ) except Exception as e: - print(f"Could not create outlier plot for {recording_name}: {e}") + logger.warning( + f"Could not create outlier plot for {recording_name}: {e}" + ) return ( cleaned_coordinates, cleaned_confidences, formatted_bodyparts, average_frame_rate, - frame_rate_list, - video_duration_list, video_metadata_list, ) @@ -518,8 +516,6 @@ def make_insert( cleaned_confidences, formatted_bodyparts, average_frame_rate, - frame_rate_list, - video_duration_list, video_metadata_list, ): # Insert the main preprocessing results From a067ce9e7389ba28a5699488ad750a555c2d0e9b Mon Sep 17 00:00:00 
2001 From: MilagrosMarin Date: Wed, 10 Sep 2025 14:39:10 +0100 Subject: [PATCH 39/61] update docstrings --- element_moseq/moseq_train.py | 262 +++++++++++------------------------ 1 file changed, 83 insertions(+), 179 deletions(-) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index d8f44eb7..faf8bd26 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -33,16 +33,14 @@ def activate( Args: train_schema_name (str): A string containing the name of the `moseq_train` schema. - create_schema (bool): If True (default), schema will be created in the database. + create_schema (bool): If True (default), schema will be created in the database. create_tables (bool): If True (default), tables related to the schema will be created in the database. linking_module (str): A string containing the module name or module containing the required dependencies to activate the schema. - Dependencies: Functions: - get_kpms_root_data_dir(): Returns absolute path for root data directory/ies. - with all behavioral recordings, as (list of) string(s). + get_kpms_root_data_dir(): Returns absolute path for root data directory/ies + with all behavioral recordings, as (list of) string(s). get_kpms_processed_data_dir(): Optional. Returns absolute path for processed - data. - + data. """ if isinstance(linking_module, str): @@ -71,13 +69,14 @@ def activate( def get_kpms_root_data_dir() -> list: - """Pulls relevant func from parent namespace to specify root data dir(s). + """Fetches absolute data path to kpms data directories. + + The absolute path here is used as a reference for all downstream relative paths used in DataJoint. - It is recommended that all paths in DataJoint Elements stored as relative - paths, with respect to some user-configured "root" directory/ies. The - root(s) may vary between data modalities and user machines. Returns a full path - string or list of strings for possible root data directories. 
+ Returns: + A list of the absolute path(s) to kpms data directories. """ + root_directories = _linking_module.get_kpms_root_data_dir() if isinstance(root_directories, (str, Path)): root_directories = [root_directories] @@ -92,11 +91,10 @@ def get_kpms_root_data_dir() -> list: def get_kpms_processed_data_dir() -> Optional[str]: - """Pulls relevant func from parent namespace. Defaults to KPMS's project /videos/. + """Retrieve the root directory for all processed data. - Method in parent namespace should provide a string to a directory where KPMS output - files will be stored. If unspecified, output files will be stored in the - session directory 'videos' folder, per Keypoint-MoSeq default. + Returns: + A string for the full path to the root directory for processed data. """ if hasattr(_linking_module, "get_kpms_processed_data_dir"): return _linking_module.get_kpms_processed_data_dir() @@ -109,7 +107,7 @@ def get_kpms_processed_data_dir() -> Optional[str]: @schema class PoseEstimationMethod(dj.Lookup): - """Pose estimation methods supported by the keypoint loader of `keypoint-moseq` package. + """Name of the pose estimation method supported by the keypoint loader of `keypoint-moseq` package. Attributes: pose_estimation_method (str): Supported pose estimation method (deeplabcut, sleap, anipose, sleap-anipose, nwb, facemap) @@ -141,13 +139,13 @@ class KeypointSet(dj.Manual): kpset_id (int) : Unique ID for each keypoint set. PoseEstimationMethod (foreign key) : Unique format method used to obtain the keypoints data. kpset_dir (str) : Path where the keypoint files are located together with the pose estimation `config` file, relative to root data directory. - kpset_desc (str) : Optional. User-entered description. + kpset_desc (str) : Optional. User-entered description. 
""" definition = """ kpset_id : int # Unique ID for each keypoint set --- - -> PoseEstimationMethod # Unique format method used to obtain the keypoints data + -> PoseEstimationMethod # Unique format method used to obtain the keypoints data kpset_dir : varchar(255) # Path where the keypoint files are located together with the pose estimation `config` file, relative to root data directory kpset_desc='' : varchar(1000) # Optional. User-entered description """ @@ -163,9 +161,9 @@ class VideoFile(dj.Part): definition = """ -> master - video_id : int # Unique ID for each video corresponding to each keypoint data file, relative to root data directory + video_id : int # Unique ID for each video corresponding to each keypoint data file, relative to root data directory --- - video_path : varchar(1000) # Filepath of each video from which the keypoints are derived, relative to root data directory + video_path : varchar(1000) # Filepath of each video from which the keypoints are derived, relative to root data directory """ @@ -179,7 +177,7 @@ class Bodyparts(dj.Manual): anterior_bodyparts (blob) : List of strings of anterior bodyparts posterior_bodyparts (blob) : List of strings of posterior bodyparts use_bodyparts (blob) : List of strings of bodyparts to be used - bodyparts_desc(varchar) : Optional. User-entered description. + bodyparts_desc (varchar) : Optional. User-entered description. """ definition = """ @@ -198,25 +196,19 @@ class PCATask(dj.Manual): """ Define the Principal Component Analysis (PCA) task for dimensionality reduction of keypoint data. - This table defines the parameters for the PCA preprocessing step, which is a prerequisite for both - stages of the Keypoint-MoSeq training pipeline. The PCA step reduces the dimensionality of the - keypoint data by projecting it onto the principal components that capture the most variance in - the pose dynamics. 
This dimensionality reduction is essential for efficient model training and - helps identify the optimal latent dimension for the subsequent AR-HMM and Keypoint-SLDS model fitting. - Attributes: Bodyparts (foreign key) : Unique ID for each `Bodyparts` key outlier_scale_factor (int) : Scale factor for outlier detection in keypoint data (default: 6) - kpms_project_output_dir (str) : Keypoint-MoSeq project output directory, relative to root data directory + kpms_project_output_dir (str) : Optional. Keypoint-MoSeq project output directory, relative to root data directory task_mode (enum) : 'load' to load existing results, 'trigger' to compute new PCA """ definition = """ - -> Bodyparts # Unique ID for each `Bodyparts` key + -> Bodyparts # Unique ID for each `Bodyparts` key --- - outlier_scale_factor=6 : int # Scale factor for outlier detection in keypoint data (default: 6) - kpms_project_output_dir='' : varchar(255) # Keypoint-MoSeq project output directory, relative to root data directory - task_mode='load' :enum('load','trigger') # 'load' to load existing results, 'trigger' to compute new PCA + outlier_scale_factor=6 : int # Scale factor for outlier detection in keypoint data (default: 6) + kpms_project_output_dir='' : varchar(255) # Optional. Keypoint-MoSeq project output directory, relative to root data directory + task_mode='load' :enum('load','trigger') # 'load' to load existing results, 'trigger' to compute new PCA """ @@ -226,20 +218,12 @@ class PreProcessing(dj.Computed): """ Preprocess keypoint data by cleaning outliers and setting up the Keypoint-MoSeq project configuration. - This table handles the initial data preprocessing step that prepares keypoint data for the PCA and - subsequent model fitting stages. It performs outlier detection and removal using medoid distance - analysis, interpolates missing keypoints, and sets up the project configuration with video paths, - bodypart specifications, and frame rate information. 
This preprocessing also calculates and updates - video duration metadata in the KeypointSet.VideoFile table. This preprocessing is essential for - ensuring data quality and proper model initialization. - Attributes: PCATask (foreign key) : Unique ID for each `PCATask` key. coordinates (longblob) : Dictionary mapping filenames to cleaned keypoint coordinates as ndarrays of shape (n_frames, n_bodyparts, 2[or 3]). confidences (longblob) : Dictionary mapping filenames to updated likelihood scores as ndarrays of shape (n_frames, n_bodyparts). formatted_bodyparts (longblob) : List of bodypart names. The order of the names matches the order of the bodyparts in `coordinates` and `confidences`. average_frame_rate (float) : Average frame rate of the videos for model training (used for kappa calculation). - frame_rates (longblob) : List of the frame rates of the videos for model training. """ definition = """ @@ -256,31 +240,30 @@ class Video(dj.Part): -> master video_name: varchar(255) --- - video_duration=0 : int # Duration of each video in minutes (if not provided, it will be automatically calculated in `PreProcessing`). - frame_rate=0 : float # Frame rate of the video. + video_duration : int # Duration of each video in minutes + frame_rate : float # Frame rate of the video in frames per second (Hz) """ def make_fetch(self, key): """ - Make function to: - 1. Generate and update the `kpms_dj_config.yml` with both the videoset directory and the bodyparts. - 2. Create the keypoint coordinates and confidences scores to format the data for the PCA fitting. + Preprocess keypoint data by cleaning outliers and setting up project configuration. Args: key (dict): Primary key from the `PCATask` table. Raises: - NotImplementedError: `pose_estimation_method` is only supported for `deeplabcut`. + NotImplementedError: Only `deeplabcut` pose estimation method is supported. + FileNotFoundError: No DLC config file found in `kpset_dir`. High-Level Logic: - 1. 
Fetches the bodyparts, format method, and the directories for the Keypoint-MoSeq project output, the keypoint set, and the video set. + 1. Fetch the bodyparts, format method, and the directories. 2. Set variables for each of the full path of the mentioned directories. 3. Find the first existing pose estimation config file in the `kpset_dir` directory, if not found, raise an error. 4. Check that the pose_estimation_method is `deeplabcut` and set up the project output directory with the default `config.yml`. 5. Create the `kpms_project_output_dir` (if it does not exist), and generates the kpms default `config.yml` with the default values from the pose estimation config. 6. Create a copy of the kpms `config.yml` named `kpms_dj_config.yml` that will be updated with both the `video_dir` and bodyparts 7. Load keypoint data from the keypoint files found in the `kpset_dir` that will serve as the training set. - 8. As a result of the keypoint loading, the coordinates and confidences scores are generated and will be used to format the data for modeling. + 8. Detect and remove outlier keypoints using medoid distance analysis, then interpolate missing values. 9. Calculate the average frame rate and the frame rate list of the videoset from which the keypoint set is derived. These two attributes can be used to calculate the kappa value. 10. Insert the results of this `make` function into the table. """ @@ -333,27 +316,6 @@ def make_compute( task_mode, outlier_scale_factor, ): - """Compute the outlier keypoints, interpolate them, and extract video metadata. - - Args: - key (dict): Primary key from the `PCATask` table. - anterior_bodyparts (list): List of anterior bodyparts. - posterior_bodyparts (list): List of posterior bodyparts. - - Raises: - NotImplementedError: `pose_estimation_method` is only supported for `deeplabcut`. - FileNotFoundError: No DLC config.(yml|yaml) found in `kpset_dir`. - Exception: Could not create outlier plot for `recording_name`. - - Logic: - 1. 
Load the project configuration for outlier detection - 2. Find outliers using medoid distance analysis - 3. Interpolate keypoints to fix outliers - 4. Update confidences for outlier points - 5. Calculate video frame rates and update video durations in KeypointSet.VideoFile table - 6. Store results - """ - from keypoint_moseq import ( find_medoid_distance_outliers, interpolate_keypoints, @@ -419,10 +381,7 @@ def make_compute( fps = float(cap.get(cv2.CAP_PROP_FPS)) frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) cap.release() - - # Calculate duration in minutes duration_minutes = (frame_count / fps) / 60.0 - frame_rates.append(fps) # Get video name for the Video part table @@ -479,7 +438,6 @@ def make_compute( # Update confidences for outlier points cleaned_conf = np.where(outliers["mask"], 0, raw_conf) - # Store results cleaned_coordinates[recording_name] = cleaned_coords cleaned_confidences[recording_name] = cleaned_conf @@ -518,7 +476,6 @@ def make_insert( average_frame_rate, video_metadata_list, ): - # Insert the main preprocessing results self.insert1( dict( **key, @@ -529,7 +486,6 @@ def make_insert( ) ) - # Insert video metadata into the Video part table for video_metadata in video_metadata_list: self.Video.insert1( dict( @@ -545,12 +501,6 @@ def make_insert( class PCAFit(dj.Computed): """Fit Principal Component Analysis (PCA) model for dimensionality reduction of keypoint data. - This table computes the PCA model that reduces the dimensionality of the keypoint data by projecting - it onto the principal components that capture the most variance in the pose dynamics. The fitted PCA - model is essential for both stages of the Keypoint-MoSeq training pipeline, as it provides the - low-dimensional representation of pose trajectories that will be used in the AR-HMM and Keypoint-SLDS - model fitting. - Attributes: PreProcessing (foreign key) : `PreProcessing` Key. pca_fit_time (datetime) : datetime of the PCA fitting analysis. 
@@ -564,19 +514,16 @@ class PCAFit(dj.Computed): def make(self, key): """ - Make function to format the keypoint data, fit the PCA model, and store it as a `pca.p` file in the Keypoint-MoSeq project output directory. + Format keypoint data and fit PCA model for dimensionality reduction. Args: key (dict): `PreProcessing` Key - Raises: - High-Level Logic: - 1. Fetch the `kpms_project_output_dir` from the `PCATask` table and define its full path. - 2. Load the `kpms_dj_config` file that contains the updated `video_dir` and bodyparts, \ - and format the keypoint data with the coordinates and confidences scores to be used in the PCA fitting. - 3. Fit the PCA model and save it as `pca.p` file in the output directory. - 4.Insert the creation datetime as the `pca_fit_time` into the table. + 1. Fetch project output directory and load configuration. + 2. Format keypoint data with coordinates and confidences. + 3. Fit PCA model and save as `pca.p` file. + 4. Insert creation datetime into table. """ from keypoint_moseq import fit_pca, format_data, save_pca @@ -610,12 +557,6 @@ class LatentDimension(dj.Imported): """ Determine the optimal latent dimension for model fitting based on variance explained by PCA components. - This table analyzes the fitted PCA model to determine the number of principal components needed to - explain a specified variance threshold (default: 90%). This analysis helps users make informed decisions - about the latent dimension parameter for both stages of the Keypoint-MoSeq training pipeline. The - recommended approach is to use the number of components that explain 90% of the variance, or a maximum - of 10 dimensions, whichever is lower. - Attributes: PCAFit (foreign key) : `PCAFit` Key. variance_percentage (float) : Variance threshold. Fixed value to 90%. @@ -633,23 +574,19 @@ class LatentDimension(dj.Imported): def make(self, key): """ - Make function to compute and store the latent dimension that explains a 90% variance threshold. 
+ Compute and store the optimal latent dimension based on 90% variance threshold. Args: key (dict): `PCAFit` Key. Raises: + FileNotFoundError: No PCA model found in project directory. High-Level Logic: - 1. Fetches the Keypoint-MoSeq project output directory from the PCATask table and define the full path. - 2. Load the PCA model from file in this directory. - 2. Set a specified variance threshold to 90% and compute the cumulative sum of the explained variance ratio. - 3. Determine the number of components required to explain the specified variance. - 3.1 If the cumulative sum of the explained variance ratio is less than the specified variance threshold, \ - it sets the `latent_dimension` to the total number of components and `variance_percentage` to the cumulative sum of the explained variance ratio. - 3.2 If the cumulative sum of the explained variance ratio is greater than the specified variance threshold, \ - it sets the `latent_dimension` to the number of components that explain the specified variance and `variance_percentage` to the specified variance threshold. - 4. Insert the results of this `make` function into the table. + 1. Fetch project output directory and load PCA model. + 2. Calculate cumulative explained variance ratio. + 3. Determine number of components needed for 90% variance. + 4. Insert results into table. """ VARIANCE_THRESHOLD = 0.90 @@ -698,41 +635,31 @@ def make(self, key): class PreFitTask(dj.Manual): """Define parameters for Stage 1: Auto-Regressive Hidden Markov Model (AR-HMM) pre-fitting. - This table defines the parameters for the first stage of the two-stage Keypoint-MoSeq training - approach. The pre-fitting stage focuses on learning temporal behavioral dynamics and discrete - syllable states without incorporating spatial pose dynamics. This stage is designed for rapid - hyperparameter exploration, particularly for finding optimal kappa values that yield desired - syllable durations. 
- Attributes: PCAFit (foreign key) : `PCAFit` task. pre_latent_dim (int) : Latent dimension to use for the model pre-fitting. pre_kappa (int) : Kappa value to use for the model pre-fitting (controls syllable duration). pre_num_iterations (int) : Number of Gibbs sampling iterations to run in the model pre-fitting (typically 10-50). + model_name (varchar) : Name of the model to be loaded if `task_mode='load'` + task_mode (enum) : 'load': load computed analysis results, 'trigger': trigger computation pre_fit_desc(varchar) : User-defined description of the pre-fitting task. """ definition = """ - -> PCAFit # `PCAFit` Key - pre_latent_dim : int # Latent dimension to use for the model pre-fitting. - pre_kappa : int # Kappa value to use for the model pre-fitting (controls syllable duration). - pre_num_iterations : int # Number of Gibbs sampling iterations to run in the model pre-fitting (typically 10-50). + -> PCAFit # `PCAFit` Key + pre_latent_dim : int # Latent dimension to use for the model pre-fitting. + pre_kappa : int # Kappa value to use for the model pre-fitting (controls syllable duration). + pre_num_iterations : int # Number of Gibbs sampling iterations to run in the model pre-fitting (typically 10-50). --- model_name='' : varchar(1000) # Name of the model to be loaded if `task_mode='load'` - task_mode='load' :enum('load','trigger')# 'load': load computed analysis results, 'trigger': trigger computation - pre_fit_desc='' : varchar(1000) # User-defined description of the pre-fitting task + task_mode='load' :enum('load','trigger') # 'load': load computed analysis results, 'trigger': trigger computation + pre_fit_desc='' : varchar(1000) # User-defined description of the pre-fitting task """ @schema class PreFit(dj.Computed): - """Stage 1: Fit Auto-Regressive Hidden Markov Model (AR-HMM) for initial behavioral syllable discovery. - - This is the first stage of the two-stage Keypoint-MoSeq training approach. 
The PreFit step focuses - exclusively on learning the temporal dynamics and discrete behavioral states (syllables) without - incorporating spatial pose dynamics. This stage is computationally efficient and allows for rapid - exploration of hyperparameters (particularly kappa for syllable duration) before the expensive - full model fitting. + """Fit Auto-Regressive Hidden Markov Model (AR-HMM) for initial behavioral syllable discovery. Attributes: PreFitTask (foreign key) : `PreFitTask` Key. @@ -741,33 +668,28 @@ class PreFit(dj.Computed): """ definition = """ - -> PreFitTask # `PreFitTask` Key + -> PreFitTask # `PreFitTask` Key --- model_name='' : varchar(1000) # Name of the model as "kpms_project_output_dir/model_name" - pre_fit_duration=NULL : float # Time duration (seconds) of the model fitting computation + pre_fit_duration=NULL : float # Time duration (seconds) of the model fitting computation """ def make(self, key): """ - Make function to fit the AR-HMM model using the latent trajectory defined by `model['states']['x']. + Fit AR-HMM model for initial behavioral syllable discovery. Args: - key (dict) : dictionary with the `PreFitTask` Key. + key (dict): Dictionary with the `PreFitTask` Key. Raises: + FileNotFoundError: No PCA model found in project directory. - High-level Logic: - 1. Fetch the `kpms_project_output_dir` and define the full path. - 2. Fetch the model parameters from the `PreFitTask` table. - 3. Update the `dj_config.yml` with the latent dimension and kappa for the AR-HMM fitting. - 4. Load the pca model from file in the `kpms_project_output_dir`. - 5. Fetch `coordinates` and `confidences` scores to format the data for the model initialization. \ - # Data - contains the data for model fitting. \ - # Metadata - contains the recordings and start/end frames for the data. - 6. Initialize the model that create a `model` dict containing states, parameters, hyperparameters, noise prior, and random seed. - 7. 
Update the model dict with the selected kappa for the AR-HMM fitting. - 8. Fit the AR-HMM model using the `pre_num_iterations` and create a subdirectory in `kpms_project_output_dir` with the model's latest checkpoint file. - 9. Calculate the duration of the model fitting computation and insert it in the `PreFit` table. + High-Level Logic: + 1. Fetch project output directory and model parameters. + 2. Update configuration with latent dimension and kappa values. + 3. Load PCA model and format keypoint data. + 4. Initialize and fit AR-HMM model. + 5. Calculate fitting duration and insert results. """ from keypoint_moseq import ( fit_model, @@ -870,21 +792,15 @@ def make(self, key): @schema class FullFitTask(dj.Manual): - """Define parameters for Stage 2: Full Keypoint Switching Linear Dynamical System (Keypoint-SLDS) model fitting. - - This table defines the parameters for the second stage of the two-stage Keypoint-MoSeq training approach. - The full fitting stage refines the model by incorporating all parameters including spatial pose dynamics, - centroid, heading, noise estimates, and continuous latent states. This stage builds upon the initial - behavioral structure discovered in the pre-fitting stage to create a complete behavioral model. - - Note: The full model will generally require a lower value of kappa to yield the same target syllable - durations compared to the pre-fitting stage, as the additional spatial dynamics provide more information. + """Define parameters for FullFit step of model fitting. Attributes: PCAFit (foreign key) : `PCAFit` Key. full_latent_dim (int) : Latent dimension to use for the model full fitting. full_kappa (int) : Kappa value to use for the model full fitting (typically lower than pre-fit kappa). full_num_iterations (int) : Number of Gibbs sampling iterations to run in the model full fitting (typically 200-500). 
+ model_name (varchar) : Name of the model to be loaded if `task_mode='load'` + task_mode (enum) : 'load': load computed analysis results, 'trigger': trigger computation full_fit_desc(varchar) : User-defined description of the model full fitting task. """ @@ -894,7 +810,7 @@ class FullFitTask(dj.Manual): full_kappa : int # Kappa value to use for the model full fitting (typically lower than pre-fit kappa). full_num_iterations : int # Number of Gibbs sampling iterations to run in the model full fitting (typically 200-500). --- - model_name='' : varchar(1000) # Name of the model to be loaded if `task_mode='load'` + model_name='' : varchar(1000) # Name of the model to be loaded if `task_mode='load'` task_mode='load' :enum('load','trigger')# Trigger or load the task full_fit_desc='' : varchar(1000) # User-defined description of the model full fitting task """ @@ -902,13 +818,7 @@ class FullFitTask(dj.Manual): @schema class FullFit(dj.Computed): - """Stage 2: Fit the complete Keypoint Switching Linear Dynamical System (Keypoint-SLDS) model. - - This is the second stage of the two-stage Keypoint-MoSeq training approach. The FullFit step refines - the model by incorporating all parameters including spatial pose dynamics, centroid, heading, noise - estimates, and continuous latent states. This stage builds upon the initial behavioral exploration - discovered in the pre-fitting stage to create a complete behavioral model with both temporal and - spatial dynamics. + """Fit the complete Keypoint Switching Linear Dynamical System (Keypoint-SLDS) model. Attributes: FullFitTask (foreign key) : `FullFitTask` Key. 
@@ -917,34 +827,29 @@ class FullFit(dj.Computed): """ definition = """ - -> FullFitTask # `FullFitTask` Key + -> FullFitTask # `FullFitTask` Key --- model_name='' : varchar(1000) # Name of the model as "kpms_project_output_dir/model_name" - full_fit_duration=NULL : float # Time duration (seconds) of the full fitting computation + full_fit_duration=NULL : float # Time duration (seconds) of the full fitting computation """ def make(self, key): """ - Make function to fit the full (keypoint-SLDS) model + Fit the complete Keypoint-SLDS model with spatial and temporal dynamics. Args: - key (dict): dictionary with the `FullFitTask` Key. + key (dict): Dictionary with the `FullFitTask` Key. Raises: + FileNotFoundError: No PCA model found in project directory. - High-level Logic: - 1. Fetch the `kpms_project_output_dir` and define the full path. - 2. Fetch the model parameters from the `FullFitTask` table. - 3. Update the `dj_config.yml` with the selected latent dimension and kappa for the full-fitting. - 4. Load the pca model from file in the `kpms_project_output_dir`. - 5. Fetch the `coordinates` and `confidences` scores to format the data for the model initialization. - 6. Initialize the model that create a `model` dict containing states, parameters, hyperparameters, noise prior, and random seed. - 7. Update the model dict with the selected kappa for the Keypoint-SLDS fitting. - 8. Fit the Keypoint-SLDS model using the `full_num_iterations` and create a subdirectory in `kpms_project_output_dir` with the model's latest checkpoint file. - 9. Reindex syllable labels by their frequency in the most recent model snapshot in the checkpoint file. \ - This function permutes the states and parameters of a saved checkpoint so that syllables are labeled \ - in order of frequency (i.e. so that 0 is the most frequent, 1 is the second most, and so on). - 10. Calculate the duration of the model fitting computation and insert it in the `FullFit` table. + High-Level Logic: + 1. 
Fetch project output directory and model parameters. + 2. Update configuration with latent dimension and kappa values. + 3. Load PCA model and format keypoint data. + 4. Initialize and fit Keypoint-SLDS model. + 5. Reindex syllable labels by frequency. + 6. Calculate fitting duration and insert results. """ from keypoint_moseq import ( estimate_sigmasq_loc, @@ -1031,8 +936,6 @@ def make(self, key): model_name=Path(model_name).parts[-1], ) - # Save of results will be applied during the Inference - else: duration_seconds = None @@ -1052,9 +955,10 @@ def make(self, key): class SelectedFullFit(dj.Manual): """Register selected FullFit models for use in the inference pipeline. - This table allows users to select and register specific FullFit models (from Stage 2 of the training - pipeline) for use in downstream inference tasks. Users can provide descriptive names and descriptions - for their trained models to facilitate model management and selection for behavioral analysis on new data. + Attributes: + FullFit (foreign key) : `FullFit` Key. + registered_model_name (varchar): User-friendly model name + registered_model_desc (varchar): Optional user-defined description """ definition = """ From 8d1ff5dc27610172ce7d503b431f4fbabc734d56 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Wed, 10 Sep 2025 14:47:41 +0100 Subject: [PATCH 40/61] docs(moseq_train) --- element_moseq/moseq_train.py | 57 +++++++++++++++++++++++------------- 1 file changed, 37 insertions(+), 20 deletions(-) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index faf8bd26..a16a35ac 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -246,26 +246,7 @@ class Video(dj.Part): def make_fetch(self, key): """ - Preprocess keypoint data by cleaning outliers and setting up project configuration. - - Args: - key (dict): Primary key from the `PCATask` table. - - Raises: - NotImplementedError: Only `deeplabcut` pose estimation method is supported. 
- FileNotFoundError: No DLC config file found in `kpset_dir`. - - High-Level Logic: - 1. Fetch the bodyparts, format method, and the directories. - 2. Set variables for each of the full path of the mentioned directories. - 3. Find the first existing pose estimation config file in the `kpset_dir` directory, if not found, raise an error. - 4. Check that the pose_estimation_method is `deeplabcut` and set up the project output directory with the default `config.yml`. - 5. Create the `kpms_project_output_dir` (if it does not exist), and generates the kpms default `config.yml` with the default values from the pose estimation config. - 6. Create a copy of the kpms `config.yml` named `kpms_dj_config.yml` that will be updated with both the `video_dir` and bodyparts - 7. Load keypoint data from the keypoint files found in the `kpset_dir` that will serve as the training set. - 8. Detect and remove outlier keypoints using medoid distance analysis, then interpolate missing values. - 9. Calculate the average frame rate and the frame rate list of the videoset from which the keypoint set is derived. These two attributes can be used to calculate the kappa value. - 10. Insert the results of this `make` function into the table. + Fetch required data for preprocessing from database tables. """ anterior_bodyparts, posterior_bodyparts, use_bodyparts = ( @@ -316,6 +297,38 @@ def make_compute( task_mode, outlier_scale_factor, ): + """ + Compute preprocessing steps including outlier removal and video metadata extraction. + + Args: + key (dict): Primary key from the `PCATask` table. + anterior_bodyparts (list): List of anterior bodyparts. + posterior_bodyparts (list): List of posterior bodyparts. + use_bodyparts (list): List of bodyparts to use. + pose_estimation_method (str): Pose estimation method (e.g., 'deeplabcut'). + kpset_dir (str): Keypoint set directory path. + video_paths (list): List of video file paths. + video_ids (list): List of video IDs. 
+ kpms_project_output_dir (str): Project output directory path. + task_mode (str): Task mode ('load' or 'trigger'). + outlier_scale_factor (int): Scale factor for outlier detection. + + Returns: + tuple: Processed data including cleaned coordinates, confidences, and video metadata. + + Raises: + NotImplementedError: Only `deeplabcut` pose estimation method is supported. + FileNotFoundError: No DLC config file found in `kpset_dir`. + + High-Level Logic: + 1. Find the first existing pose estimation config file in the `kpset_dir` directory, if not found, raise an error. + 2. Check that the pose_estimation_method is `deeplabcut` and set up the project output directory with the default `config.yml`. + 3. Create the `kpms_project_output_dir` (if it does not exist), and generates the kpms default `config.yml` with the default values from the pose estimation config. + 4. Create a copy of the kpms `config.yml` named `kpms_dj_config.yml` that will be updated with both the `video_dir` and bodyparts + 5. Load keypoint data from the keypoint files found in the `kpset_dir` that will serve as the training set. + 6. Detect and remove outlier keypoints using medoid distance analysis, then interpolate missing values. + 7. Calculate the average frame rate and the frame rate list of the videoset from which the keypoint set is derived. These two attributes can be used to calculate the kappa value. + """ from keypoint_moseq import ( find_medoid_distance_outliers, interpolate_keypoints, @@ -476,6 +489,10 @@ def make_insert( average_frame_rate, video_metadata_list, ): + """ + Insert processed data into the PreProcessing table and Video part table. 
+ """ + self.insert1( dict( **key, From 9e37ca9ab6759b888e29392f83de19b1e07c13a9 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Wed, 10 Sep 2025 14:49:29 +0100 Subject: [PATCH 41/61] refactor(moseq_infer): apply 3-part make function and update docstrings --- element_moseq/moseq_infer.py | 123 +++++++++++++++++++++++------------ 1 file changed, 81 insertions(+), 42 deletions(-) diff --git a/element_moseq/moseq_infer.py b/element_moseq/moseq_infer.py index afb14251..2a502290 100644 --- a/element_moseq/moseq_infer.py +++ b/element_moseq/moseq_infer.py @@ -68,25 +68,23 @@ def activate( @schema class Model(dj.Manual): - """Register a model. + """Register a trained model. Attributes: model_id (int) : Unique ID for each model. model_name (varchar) : User-friendly model name. - model_dir (varchar) : Model directory relative to root data directory (e.g. `kpms_project_output_dir/2024_03_21-00_51_39`) - latent_dim (int) : Latent dimension of the model. - kappa (float) : Kappa value of the model. - model_desc (varchar) : Optional. User-defined description of the model + model_dir (varchar) : Model directory relative to root data directory. + model_desc (varchar) : Optional. User-defined description of the model. """ definition = """ - model_id : int # Unique ID for each model + model_id : int # Unique ID for each model --- - model_name : varchar(1000) # User-friendly model name - model_dir : varchar(1000)# Model directory relative to root data directory - model_desc='' : varchar(1000)# Optional. User-defined description of the model - -> [nullable] moseq_train.SelectedFullFit + model_name : varchar(1000) # User-friendly model name + model_dir : varchar(1000) # Model directory relative to root data directory + model_desc='' : varchar(1000) # Optional. User-defined description of the model + -> [nullable] moseq_train.SelectedFullFit # Optional. FullFit key. """ @@ -135,6 +133,7 @@ class InferenceTask(dj.Manual): inference_output_dir (varchar) : Optional. 
Sub-directory where the results will be stored. inference_desc (varchar) : Optional. User-defined description of the inference task. num_iterations (int) : Optional. Number of iterations to use for the model inference. If null, the default number internally is 50. + task_mode (enum) : 'load': load computed analysis results, 'trigger': trigger computation """ definition = """ @@ -152,10 +151,11 @@ class InferenceTask(dj.Manual): @schema class Inference(dj.Computed): - """Infer the model from the checkpoint file and save the results as `results.h5` file. + """Infer the model from the checkpoint file and generate the results files. Attributes: InferenceTask (foreign_key) : `InferenceTask` key. + file_h5 (attach) : File path of the results.h5 file. inference_duration (float) : Time duration (seconds) of the inference computation. """ @@ -170,11 +170,13 @@ class MotionSequence(dj.Part): """Store the results of the model inference. Attributes: + InferenceTask (foreign key) : `InferenceTask` key. video_name (varchar) : Name of the video. syllable (longblob) : Syllable labels (z). The syllable label assigned to each frame (i.e. the state indexes assigned by the model). latent_state (longblob) : Inferred low-dim pose state (x). Low-dimensional representation of the animal's pose in each frame. These are similar to PCA scores, are modified to reflect the pose dynamics and noise estimates inferred by the model. centroid (longblob) : Inferred centroid (v). The centroid of the animal in each frame, as estimated by the model. heading (longblob) : Inferred heading (h). The heading of the animal in each frame, as estimated by the model. + file_csv (attach) : File path of the results.csv file. 
""" definition = """ @@ -203,12 +205,56 @@ class GridMoviesSampledInstances(dj.Part): instances: longblob # List of instances shown in each in grid movie (in row-major order), where each instance is specified as a tuple with the video name, start frame and end frame """ - def make(self, key): + def make_fetch(self, key): """ - This function is used to infer the model results from the checkpoint file and store the results in `MotionSequence` and `GridMoviesSampledInstances` tables. + Fetch data required for model inference. + """ + ( + keypointset_dir, + inference_output_dir, + num_iterations, + model_id, + pose_estimation_method, + task_mode, + ) = (InferenceTask & key).fetch1( + "keypointset_dir", + "inference_output_dir", + "num_iterations", + "model_id", + "pose_estimation_method", + "task_mode", + ) + + return ( + keypointset_dir, + inference_output_dir, + num_iterations, + model_id, + pose_estimation_method, + task_mode, + ) + + def make_compute( + self, + key, + keypointset_dir, + inference_output_dir, + num_iterations, + model_id, + pose_estimation_method, + task_mode, + ): + """ + Compute model inference results. Args: key (dict): `InferenceTask` primary key. + keypointset_dir (str): Directory containing keypoint data. + inference_output_dir (str): Output directory for inference results. + num_iterations (int): Number of iterations for model fitting. + model_id (int): Model ID. + pose_estimation_method (str): Pose estimation method. + task_mode (str): Task mode ('trigger' or 'load'). Raises: FileNotFoundError: If no pca model (`pca.p`) found in the parent model directory. @@ -216,19 +262,8 @@ def make(self, key): NotImplementedError: If the format method is not `deeplabcut`. FileNotFoundError: If no valid `kpms_dj_config` found in the parent model directory. - High-level Logic: - 1. Fetch the `inference_output_dir` where the results will be stored, and if it does not exist, create it. - 2. 
Fetch the `model_name` and the `num_iterations` from the `InferenceTask` table. - 3. Load the most recent model checkpoint and the pca model from files in the `kpms_project_output_dir`. - 4. Load the keypoint data for inference as `filepath_patterns` and format it. - 5. Initialize and apply the model with the new keypoint data. - 6. If the `num_iterations` is set, fit the model with the new keypoint data for `num_iterations` iterations; otherwise, fit the model with the default number of iterations (50). - 7. Save the results as a CSV file and store the histogram showing the frequency of each syllable. - 8. Generate and save the plots showing the median trajectory of poses associated with each given syllable. - 9. Generate and save video clips showing examples of each syllable. - 10. Generate and save the dendrogram representing distances between each syllable's median trajectory. - 11. Insert the inference duration in the `Inference` table. - 12. Insert the results in the `MotionSequence` and `GridMoviesSampledInstances` tables. + Returns: + tuple: Inference results including duration, results data, and sampled instances. """ from keypoint_moseq import ( apply_model, @@ -243,22 +278,6 @@ def make(self, key): save_results_as_csv, ) - ( - keypointset_dir, - inference_output_dir, - num_iterations, - model_id, - pose_estimation_method, - task_mode, - ) = (InferenceTask & key).fetch1( - "keypointset_dir", - "inference_output_dir", - "num_iterations", - "model_id", - "pose_estimation_method", - "task_mode", - ) - kpms_root = moseq_train.get_kpms_root_data_dir() kpms_processed = moseq_train.get_kpms_processed_data_dir() @@ -402,6 +421,24 @@ def make(self, key): duration_seconds = None + return ( + duration_seconds, + results, + sampled_instances, + inference_output_dir, + ) + + def make_insert( + self, + key, + duration_seconds, + results, + sampled_instances, + inference_output_dir, + ): + """ + Insert inference results into the database. 
+ """ self.insert1( { **key, @@ -410,6 +447,7 @@ def make(self, key): } ) + # Insert motion sequence results for result_idx, result in results.items(): self.MotionSequence.insert1( { @@ -425,6 +463,7 @@ def make(self, key): } ) + # Insert grid movie sampled instances for syllable, sampled_instance in sampled_instances.items(): self.GridMoviesSampledInstances.insert1( {**key, "syllable": syllable, "instances": sampled_instance} From ebba3e7a289f555f70dbaa34d9772269d623ee87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Wed, 10 Sep 2025 15:22:05 +0100 Subject: [PATCH 42/61] Update element_moseq/readers/kpms_reader.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- element_moseq/readers/kpms_reader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/element_moseq/readers/kpms_reader.py b/element_moseq/readers/kpms_reader.py index 14ec2e49..3a082fc9 100644 --- a/element_moseq/readers/kpms_reader.py +++ b/element_moseq/readers/kpms_reader.py @@ -128,7 +128,7 @@ def load_kpms_dj_config( use_bps = cfg.get("use_bodyparts", []) cfg["anterior_idxs"] = jnp.array( [use_bps.index(bp) for bp in anterior] - ) # same indexing approach as upstream. :contentReference[oaicite:1]{index=1} + ) # same indexing approach as upstream. 
cfg["posterior_idxs"] = jnp.array([use_bps.index(bp) for bp in posterior])
 
     if "skeleton" not in cfg or cfg["skeleton"] is None:

From 7dfc2085533b75c602826def7669ac45246f0f7b Mon Sep 17 00:00:00 2001
From: MilagrosMarin
Date: Wed, 10 Sep 2025 15:37:13 +0100
Subject: [PATCH 43/61] update CHANGELOG and bump version to 1.0.0 instead and
 according to all the new changes

---
 CHANGELOG.md             | 44 +++++++++++++++++++++++---------------
 element_moseq/version.py |  2 +-
 pyproject.toml           |  2 +-
 3 files changed, 30 insertions(+), 18 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index fe7e9668..2e9052f1 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -3,26 +3,38 @@
 Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and
 [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention.
 
-## [0.4.0] - 2025-09-09
+## [1.0.0] - 2025-09-10
 
 > **BREAKING CHANGES** - This version contains breaking changes due to keypoint-moseq upgrade and API refactoring. Please review the changes below and update your code accordingly.
 
-+ Feat - **BREAKING**: Upgrade keypoint-moseq from pinned 0.4.8 version to the latest version with breaking changes adding new features that are not compatible with the previous kpms versions.
-+ Fix - **BREAKING**: Rename `kpms_reader` functions and add support for both `config.yml` and `config.yaml` file extensions
-+ Fix - **BREAKING**: Correct generation of `kpms_dj_config.yml` and refactor `moseq_train` and `moseq_infer` to use the renamed functions.
-+ Fix - **BREAKING**: Rename `PCAPrep` to `PreProcessing` -+ Feat - Add new attribute `outlier_scale_factor` in `PCATask` table -+ Feat - **BREAKING**: Add feature to remove outlier keypoints in `PreProcessing` table -+ Feat - Refactor `PreProcessing` table to use 3-part make function and add a new `Video` part table -+ Feat - Add new attributes `video_duration`, `frame_rate` and `average_frame_rate`in `PreProcessing` table and add new `Video` table to store these new computations -+ Feat - **BREAKING**: Add `sigmasq_loc` feature in `PreFit` and `FullFit` to automatically estimate sigmasq_loc (prior controlling the centroid movement across frames) +### Breaking Changes ++ **BREAKING**: Upgrade keypoint-moseq from pinned 0.4.8 version to the latest version from source with breaking changes adding new features that are not compatible with the previous kpms versions ++ **BREAKING**: Rename `kpms_reader` functions to generate, load, and update kpms config files ++ **BREAKING**: Rename `PCAPrep` to `PreProcessing` table and add new attributes ++ **BREAKING**: Add feature to remove outlier keypoints in `PreProcessing` table with new `outlier_scale_factor` attribute in `PCATask`, only available in latest version of kpms ++ **BREAKING**: Add `sigmasq_loc` feature in `PreFit` and `FullFit` to automatically estimate sigmasq_loc (prior controlling the centroid movement across frames), only available in latest version of kpms + +### New Features ++ Feat - Add support to load from both DLC `config.yml` and `config.yaml` file extensions ++ Feat - Add new `report` schema with comprehensive reporting capabilities ++ Feat - Refactor `PreProcessing` table to use 3-part make function, add a new `Video` part table, and add new attributes `video_duration`, `frame_rate` and `average_frame_rate` to store these new computations ++ Feat - Move and refactor `viz_utils` into new `plotting` module ++ Feat - Update `model_name` varchar and folder naming ++ Feat - Migrate from `setup.py` to 
`pyproject.toml` and `conda_env.yml` for modern Python packaging standards ++ Fix - Update devcontainer to use Python 3.11 and upgrade dependencies + Fix - Remove JAX dependencies from `pyproject.toml` -+ Fix - Update dockerfile to use Python 3.11 and upgrade dependencies -+ Fix - Update folder name to a string of combined primary attributes instead of datetime in `PreFit` and `FullFit` -+ Fix - Refactor `moseq_infer` paths -+ Add - Update docstrings -+ Add - Update Images -+ Add - Update pre-commit file ++ Fix - Update pre-commit hooks for improved linting and consistency ++ Fix - Correct generation of `kpms_dj_config.yml` and refactor `moseq_train` and `moseq_infer` to use the renamed functions ++ Fix - Refactor `moseq_infer` paths and apply 3-part make function ++ Fix - Update folder naming logic to use string of combined primary attributes instead of datetime in `PreFit` and `FullFit` ++ Fix - Improve path and directory handling using `Path` objects and robust existence checks ++ Fix - Remove redundancy of variables in `PreProcessing` table ++ Fix - Update deprecated datetime usage ++ Fix - Fix filename generation in `Inference` table ++ Fix - Bugfix in `Model` imported foreign key ++ Fix - Update tutorial_pipeline ++ Add - Update docstrings across all modules ++ Add - Update pipeline images to reflect new architecture ## [0.3.2] - 2025-08-25 + Feat - modernize packaging and environment management migrating from `setup.py` to `pyproject.toml`and `env.yml` diff --git a/element_moseq/version.py b/element_moseq/version.py index afa728c9..0ec000f6 100644 --- a/element_moseq/version.py +++ b/element_moseq/version.py @@ -2,4 +2,4 @@ Package metadata """ -__version__ = "0.4.0" +__version__ = "1.0.0" diff --git a/pyproject.toml b/pyproject.toml index 7d5cb059..f4a38cf9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "element-moseq" -version = "0.4.0" +version = "1.0.0" description = 
"Keypoint-MoSeq DataJoint Element" readme = "README.md" license = {text = "MIT"} From b34404b8bbbe9584ecc023698f962857b72c4d87 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Wed, 10 Sep 2025 15:39:55 +0100 Subject: [PATCH 44/61] minor refactor in PreProcessing --- element_moseq/moseq_train.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index a16a35ac..9fafa273 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -264,11 +264,9 @@ def make_fetch(self, key): "video_path", "video_id" ) - kpms_project_output_dir, task_mode = (PCATask & key).fetch1( - "kpms_project_output_dir", "task_mode" - ) - - outlier_scale_factor = (PCATask & key).fetch1("outlier_scale_factor") + kpms_project_output_dir, task_mode, outlier_scale_factor = ( + PCATask & key + ).fetch1("kpms_project_output_dir", "task_mode", "outlier_scale_factor") return ( anterior_bodyparts, From e5f279b5d24c8845f35c52351956b5a8128e30c0 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Wed, 10 Sep 2025 15:48:18 +0100 Subject: [PATCH 45/61] update docstrings in kpms_reader --- element_moseq/readers/kpms_reader.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/element_moseq/readers/kpms_reader.py b/element_moseq/readers/kpms_reader.py index 3a082fc9..59b8bc4e 100644 --- a/element_moseq/readers/kpms_reader.py +++ b/element_moseq/readers/kpms_reader.py @@ -83,12 +83,8 @@ def dj_generate_config(project_dir: str, **kwargs) -> str: ) with open(base_cfg_path, "r") as f: cfg = yaml.safe_load(f) or {} - - # Upstream uses shallow updates for top-level keys in generate_config. - # We follow that (simple `dict.update`); nested blocks can be passed explicitly. cfg.update(kwargs) - # Upstream ensures skeleton exists; we do the same. 
if "skeleton" not in cfg or cfg["skeleton"] is None: cfg["skeleton"] = [] @@ -118,17 +114,13 @@ def load_kpms_dj_config( cfg = yaml.safe_load(f) or {} if check_if_valid: - _check_config_validity( - cfg - ) # readthedocs source mirrors this logic. :contentReference[oaicite:0]{index=0} + _check_config_validity(cfg) if build_indexes: anterior = cfg.get("anterior_bodyparts", []) posterior = cfg.get("posterior_bodyparts", []) use_bps = cfg.get("use_bodyparts", []) - cfg["anterior_idxs"] = jnp.array( - [use_bps.index(bp) for bp in anterior] - ) # same indexing approach as upstream. + cfg["anterior_idxs"] = jnp.array([use_bps.index(bp) for bp in anterior]) cfg["posterior_idxs"] = jnp.array([use_bps.index(bp) for bp in posterior]) if "skeleton" not in cfg or cfg["skeleton"] is None: From d7f503bd252da64460e88bb91962942197bf63dd Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Wed, 10 Sep 2025 17:50:44 +0100 Subject: [PATCH 46/61] update CHANGELOG --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2e9052f1..810def3b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,7 +25,7 @@ Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and + Fix - Remove JAX dependencies from `pyproject.toml` + Fix - Update pre-commit hooks for improved linting and consistency + Fix - Correct generation of `kpms_dj_config.yml` and refactor `moseq_train` and `moseq_infer` to use the renamed functions -+ Fix - Refactor `moseq_infer` paths and apply 3-part make function ++ Fix - Refactor `moseq_infer` and `moseq_train`, and implement a three-part make function in the most resource-intensive functions + Fix - Update folder naming logic to use string of combined primary attributes instead of datetime in `PreFit` and `FullFit` + Fix - Improve path and directory handling using `Path` objects and robust existence checks + Fix - Remove redundancy of variables in `PreProcessing` table From 
2603729089643252bd94d50584933ff18d9a20d4 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Milagros=20Mar=C3=ADn?=
Date: Tue, 16 Sep 2025 00:34:03 +0100
Subject: [PATCH 47/61] Update element_moseq/report.py

Co-authored-by: Thinh Nguyen
---
 element_moseq/report.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/element_moseq/report.py b/element_moseq/report.py
index 9d20a307..f7a3ff50 100644
--- a/element_moseq/report.py
+++ b/element_moseq/report.py
@@ -206,7 +206,7 @@ def make(self, key):
             moseq_train.get_kpms_processed_data_dir(), prefit_model_name
         )
         prefit_output_dir = Path(prefit_model_dir) / "fitting_progress.pdf"
-        if os.path.exists(prefit_output_dir):
+        if prefit_output_dir.exists():
             self.insert1({**key, "fitting_progress_pdf": prefit_output_dir})
         else:
             raise FileNotFoundError(

From ec6ad58cdb0412f31e77f5294c20cc56192397ee Mon Sep 17 00:00:00 2001
From: MilagrosMarin
Date: Tue, 16 Sep 2025 01:12:52 +0100
Subject: [PATCH 48/61] from `report` to `moseq_report` and refactor path exists

---
 element_moseq/{report.py => moseq_report.py} | 16 ++++++++--------
 element_moseq/moseq_train.py                 |  2 +-
 2 files changed, 9 insertions(+), 9 deletions(-)

diff --git a/element_moseq/report.py b/element_moseq/moseq_report.py
similarity index 96%
rename from element_moseq/report.py
rename to element_moseq/moseq_report.py
index 9d20a307..9e730fc2 100644
--- a/element_moseq/report.py
+++ b/element_moseq/moseq_report.py
@@ -205,12 +205,12 @@ def make(self, key):
         prefit_model_dir = find_full_path(
             moseq_train.get_kpms_processed_data_dir(), prefit_model_name
         )
-        prefit_output_dir = Path(prefit_model_dir) / "fitting_progress.pdf"
-        if prefit_output_dir.exists():
-            self.insert1({**key, "fitting_progress_pdf": prefit_output_dir})
+        prefit_output_file = Path(prefit_model_dir) / "fitting_progress.pdf"
+        if prefit_output_file.exists():
+            self.insert1({**key, "fitting_progress_pdf": prefit_output_file})
         else:
             raise 
FileNotFoundError(
-                f"PreFit fitting_progress.pdf not found at {prefit_output_dir}"
+                f"PreFit fitting_progress.pdf not found at {prefit_output_file}"
             )
 
 
@@ -227,12 +227,12 @@ def make(self, key):
         fullfit_model_dir = find_full_path(
             moseq_train.get_kpms_processed_data_dir(), fullfit_model_name
         )
-        fullfit_output_dir = Path(fullfit_model_dir) / "fitting_progress.pdf"
-        if os.path.exists(fullfit_output_dir):
-            self.insert1({**key, "fitting_progress_pdf": fullfit_output_dir})
+        fullfit_output_file = Path(fullfit_model_dir) / "fitting_progress.pdf"
+        if fullfit_output_file.exists():
+            self.insert1({**key, "fitting_progress_pdf": fullfit_output_file})
         else:
             raise FileNotFoundError(
-                f"FullFit fitting_progress.pdf not found at {fullfit_output_dir}"
+                f"FullFit fitting_progress.pdf not found at {fullfit_output_file}"
             )
 
 
diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py
index 9fafa273..84d5261f 100644
--- a/element_moseq/moseq_train.py
+++ b/element_moseq/moseq_train.py
@@ -356,7 +356,7 @@ def make_compute(
         from .readers.kpms_reader import _base_config_path
 
         cfg_path = _base_config_path(kpset_dir)
-        if not os.path.exists(cfg_path):
+        if not cfg_path.exists():
             raise FileNotFoundError(
                 f"No DLC config.(yml|yaml) found in {kpset_dir}"
             )

From 70210fa268bfad91b7e21c06e96de69728d7f433 Mon Sep 17 00:00:00 2001
From: MilagrosMarin
Date: Tue, 16 Sep 2025 01:13:47 +0100
Subject: [PATCH 49/61] refactor `inference`

---
 element_moseq/moseq_infer.py | 122 ++++++++++++++++++-----------------
 1 file changed, 64 insertions(+), 58 deletions(-)

diff --git a/element_moseq/moseq_infer.py b/element_moseq/moseq_infer.py
index 2a502290..aeb7951d 100644
--- a/element_moseq/moseq_infer.py
+++ b/element_moseq/moseq_infer.py
@@ -267,17 +267,28 @@ def make_compute(
     """
     from keypoint_moseq import (
         apply_model,
+        filter_centroids_headings,
         format_data,
         generate_grid_movies,
         generate_trajectory_plots,
+        get_syllable_instances,
         load_checkpoint,
         load_keypoints,
         load_pca,
+        
load_results, plot_similarity_dendrogram, plot_syllable_frequencies, + sample_instances, save_results_as_csv, ) + # Constants used by default as in kpms + DEFAULT_NUM_ITERS = 50 + FILTER_SIZE = 9 + MIN_DURATION = 3 + MIN_FREQUENCY = 0.005 + GRID_SAMPLES = 4 * 6 # minimum rows * cols + kpms_root = moseq_train.get_kpms_root_data_dir() kpms_processed = moseq_train.get_kpms_processed_data_dir() @@ -314,11 +325,6 @@ def make_compute( coordinates, confidences, _ = load_keypoints( filepath_pattern=keypointset_dir, format=pose_estimation_method ) - else: - raise NotImplementedError( - "The currently supported format method is `deeplabcut`. If you require \ - support for another format method, please reach out to us at `support@datajoint.com`." - ) kpms_dj_config = kpms_reader.load_kpms_dj_config(model_dir.parent) @@ -340,8 +346,7 @@ def make_compute( model_name=Path(model_dir).name, results_path=(inference_output_dir / "results.h5").as_posix(), return_model=False, - num_iters=num_iterations - or 50, # default internal value in the keypoint-moseq function + num_iters=num_iterations or DEFAULT_NUM_ITERS, **kpms_dj_config, ) end_time = datetime.now(timezone.utc) @@ -381,12 +386,6 @@ def make_compute( ) else: - from keypoint_moseq import ( - filter_centroids_headings, - get_syllable_instances, - load_results, - sample_instances, - ) # load results results = load_results( @@ -394,37 +393,60 @@ def make_compute( model_name=inference_output_dir.parts[-1], ) - # extract syllables from results - syllables = {k: v["syllable"] for k, v in results.items()} + # extract syllables from results + syllables = {k: v["syllable"] for k, v in results.items()} - # extract and smooth centroids and headings - centroids = {k: v["centroid"] for k, v in results.items()} - headings = {k: v["heading"] for k, v in results.items()} + # extract and smooth centroids and headings + centroids = {k: v["centroid"] for k, v in results.items()} + headings = {k: v["heading"] for k, v in results.items()} - 
filter_size = 9 # default value - centroids, headings = filter_centroids_headings( - centroids, headings, filter_size=filter_size - ) + centroids, headings = filter_centroids_headings( + centroids, headings, filter_size=FILTER_SIZE + ) - # extract sample instances for each syllable - syllable_instances = get_syllable_instances( - syllables, min_duration=3, min_frequency=0.005 - ) - # Map each syllable to a list of its sampled events. - sampled_instances = sample_instances( - syllable_instances=syllable_instances, - num_samples=4 * 6, # minimum rows * cols - coordinates=coordinates, - centroids=centroids, - headings=headings, + # extract sample instances for each syllable + syllable_instances = get_syllable_instances( + syllables, min_duration=MIN_DURATION, min_frequency=MIN_FREQUENCY + ) + # Map each syllable to a list of its sampled events. + sampled_instances = sample_instances( + syllable_instances=syllable_instances, + num_samples=GRID_SAMPLES, + coordinates=coordinates, + centroids=centroids, + headings=headings, + ) + + duration_seconds = None + + # Prepare motion sequence data + motion_sequence_data = [] + for result_idx, result in results.items(): + motion_sequence_data.append( + { + **key, + "video_name": result_idx, + "syllable": result["syllable"], + "latent_state": result["latent_state"], + "centroid": result["centroid"], + "heading": result["heading"], + "file_csv": ( + inference_output_dir / "results_as_csv" / f"{result_idx}.csv" + ).as_posix(), + } ) - duration_seconds = None + # Prepare grid movie data + grid_movie_data = [] + for syllable, sampled_instance in sampled_instances.items(): + grid_movie_data.append( + {**key, "syllable": syllable, "instances": sampled_instance} + ) return ( duration_seconds, - results, - sampled_instances, + motion_sequence_data, + grid_movie_data, inference_output_dir, ) @@ -432,8 +454,8 @@ def make_insert( self, key, duration_seconds, - results, - sampled_instances, + motion_sequence_data, + grid_movie_data, 
inference_output_dir,
     ):
         """
@@ -447,24 +469,8 @@
             }
         )
 
-        # Insert motion sequence results
-        for result_idx, result in results.items():
-            self.MotionSequence.insert1(
-                {
-                    **key,
-                    "video_name": result_idx,
-                    "syllable": result["syllable"],
-                    "latent_state": result["latent_state"],
-                    "centroid": result["centroid"],
-                    "heading": result["heading"],
-                    "file_csv": (
-                        inference_output_dir / "results_as_csv" / f"{result_idx}.csv"
-                    ).as_posix(),
-                }
-            )
+        # Add key to each motion sequence record and insert
+        self.MotionSequence.insert(motion_sequence_data)
 
-        # Insert grid movie sampled instances
-        for syllable, sampled_instance in sampled_instances.items():
-            self.GridMoviesSampledInstances.insert1(
-                {**key, "syllable": syllable, "instances": sampled_instance}
-            )
+        # Add key to each grid movie record and insert
+        self.GridMoviesSampledInstances.insert(grid_movie_data)

From e837f34d4c6d729ea73d97bfd825367df756fc55 Mon Sep 17 00:00:00 2001
From: MilagrosMarin
Date: Tue, 16 Sep 2025 01:14:36 +0100
Subject: [PATCH 50/61] update CHANGELOG

---
 CHANGELOG.md | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 810def3b..9feb0a95 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -14,9 +14,9 @@ Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and
 + **BREAKING**: Add feature to remove outlier keypoints in `PreProcessing` table with new `outlier_scale_factor` attribute in `PCATask`, only available in latest version of kpms
 + **BREAKING**: Add `sigmasq_loc` feature in `PreFit` and `FullFit` to automatically estimate sigmasq_loc (prior controlling the centroid movement across frames), only available in latest version of kpms
 
-### New Features
+### New Features and Fixes
 + Feat - Add support to load from both DLC `config.yml` and `config.yaml` file extensions
-+ Feat - Add new `report` schema with comprehensive reporting capabilities
++ Feat - Add new `moseq_report` schema with 
comprehensive reporting capabilities + Feat - Refactor `PreProcessing` table to use 3-part make function, add a new `Video` part table, and add new attributes `video_duration`, `frame_rate` and `average_frame_rate` to store these new computations + Feat - Move and refactor `viz_utils` into new `plotting` module + Feat - Update `model_name` varchar and folder naming From 684ebe609e20a0e5e120e648df8024c7cf995b2d Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 16 Sep 2025 01:32:30 +0100 Subject: [PATCH 51/61] update tutorial_pipeline --- notebooks/tutorial_pipeline.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/notebooks/tutorial_pipeline.py b/notebooks/tutorial_pipeline.py index ffae002d..b9635f78 100644 --- a/notebooks/tutorial_pipeline.py +++ b/notebooks/tutorial_pipeline.py @@ -4,7 +4,7 @@ from element_lab import lab from element_animal import subject from element_session import session_with_datetime as session -from element_moseq import moseq_train, moseq_infer, report +from element_moseq import moseq_train, moseq_infer, moseq_report from element_animal.subject import Subject from element_lab.lab import Source, Lab, Protocol, User, Project @@ -38,7 +38,7 @@ def get_kpms_processed_data_dir() -> str: return None -__all__ = ["lab", "subject", "session", "moseq_train", "moseq_infer", "Device"] +__all__ = ["lab", "subject", "session", "moseq_train", "moseq_infer", "moseq_report", "Device"] # Activate schemas ------------- @@ -80,4 +80,4 @@ class Device(dj.Lookup): moseq_train.activate(db_prefix + "moseq_train", linking_module=__name__) moseq_infer.activate(db_prefix + "moseq_infer", linking_module=__name__) -report.activate(db_prefix + "report", linking_module=__name__) +moseq_report.activate(db_prefix + "moseq_report", linking_module=__name__) From 2954d51cb24adf469d209a9d738d6784ee7c569f Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 16 Sep 2025 01:33:51 +0100 Subject: [PATCH 52/61] black formatting in `tutorial_pipeline` 
--- notebooks/tutorial_pipeline.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/notebooks/tutorial_pipeline.py b/notebooks/tutorial_pipeline.py index b9635f78..0f86f32e 100644 --- a/notebooks/tutorial_pipeline.py +++ b/notebooks/tutorial_pipeline.py @@ -38,7 +38,15 @@ def get_kpms_processed_data_dir() -> str: return None -__all__ = ["lab", "subject", "session", "moseq_train", "moseq_infer", "moseq_report", "Device"] +__all__ = [ + "lab", + "subject", + "session", + "moseq_train", + "moseq_infer", + "moseq_report", + "Device", +] # Activate schemas ------------- From ddf2eea71aa230726f00041682d46cc3f5717734 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 16 Sep 2025 16:35:00 +0100 Subject: [PATCH 53/61] refactor(inference): from `insert` to `insert1` in a loop to prevent timeout error --- element_moseq/moseq_infer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/element_moseq/moseq_infer.py b/element_moseq/moseq_infer.py index aeb7951d..63ab4974 100644 --- a/element_moseq/moseq_infer.py +++ b/element_moseq/moseq_infer.py @@ -469,8 +469,8 @@ def make_insert( } ) - # Add key to each motion sequence record and insert - self.MotionSequence.insert(motion_sequence_data) + for motion_record in motion_sequence_data: + self.MotionSequence.insert1(motion_record) - # Add key to each grid movie record and insert - self.GridMoviesSampledInstances.insert(grid_movie_data) + for grid_record in grid_movie_data: + self.GridMoviesSampledInstances.insert1(grid_record) From 971fd3819627ce95ea0c4c7463f994dc51372a47 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 16 Sep 2025 16:54:27 +0100 Subject: [PATCH 54/61] refactor(inference): rename attributes `file_h5` and `file_csv` --- element_moseq/moseq_infer.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/element_moseq/moseq_infer.py b/element_moseq/moseq_infer.py index 63ab4974..cfbb4649 100644 --- 
a/element_moseq/moseq_infer.py +++ b/element_moseq/moseq_infer.py @@ -151,19 +151,19 @@ class InferenceTask(dj.Manual): @schema class Inference(dj.Computed): - """Infer the model from the checkpoint file and generate the results files. + """Infer the model from the checkpoint file and generate the results of segmenting continuous behavior into discrete syllables. Attributes: - InferenceTask (foreign_key) : `InferenceTask` key. - file_h5 (attach) : File path of the results.h5 file. - inference_duration (float) : Time duration (seconds) of the inference computation. + InferenceTask (foreign_key) : `InferenceTask` key. + syllable_segmentation_file (attach) : File path of the syllable analysis results (HDF5 format) containing syllable labels, latent states, centroids, and headings. + inference_duration (float) : Time duration (seconds) of the inference computation. """ definition = """ - -> InferenceTask # `InferenceTask` key + -> InferenceTask # `InferenceTask` key --- - file_h5 : attach # File path of the results.h5 file - inference_duration=NULL : float # Time duration (seconds) of the inference computation + syllable_segmentation_file : attach # File path of the syllable analysis results (HDF5 format) containing syllable labels, latent states, centroids, and headings + inference_duration=NULL : float # Time duration (seconds) of the inference computation """ class MotionSequence(dj.Part): @@ -176,7 +176,7 @@ class MotionSequence(dj.Part): latent_state (longblob) : Inferred low-dim pose state (x). Low-dimensional representation of the animal's pose in each frame. These are similar to PCA scores, are modified to reflect the pose dynamics and noise estimates inferred by the model. centroid (longblob) : Inferred centroid (v). The centroid of the animal in each frame, as estimated by the model. heading (longblob) : Inferred heading (h). The heading of the animal in each frame, as estimated by the model. - file_csv (attach) : File path of the results.csv file. 
+ motion_sequence_file (attach) : File path of the motion sequence data (CSV format). """ definition = """ @@ -187,7 +187,7 @@ class MotionSequence(dj.Part): latent_state : longblob # Inferred low-dim pose state (x). Low-dimensional representation of the animal's pose in each frame. These are similar to PCA scores, are modified to reflect the pose dynamics and noise estimates inferred by the model centroid : longblob # Inferred centroid (v). The centroid of the animal in each frame, as estimated by the model heading : longblob # Inferred heading (h). The heading of the animal in each frame, as estimated by the model - file_csv : attach # File path of the results.csv file + motion_sequence_file: attach # File path of the temporal sequence of motion sequence data (CSV format) """ class GridMoviesSampledInstances(dj.Part): @@ -430,7 +430,7 @@ def make_compute( "latent_state": result["latent_state"], "centroid": result["centroid"], "heading": result["heading"], - "file_csv": ( + "motion_sequence_file": ( inference_output_dir / "results_as_csv" / f"{result_idx}.csv" ).as_posix(), } @@ -465,7 +465,9 @@ def make_insert( { **key, "inference_duration": duration_seconds, - "file_h5": (inference_output_dir / "results.h5").as_posix(), + "syllable_segmentation_file": ( + inference_output_dir / "results.h5" + ).as_posix(), } ) From 5c5b8e3efb976e6b6acd46a3485a0371d59e3da3 Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 16 Sep 2025 17:00:37 +0100 Subject: [PATCH 55/61] update docstring in inference --- element_moseq/moseq_infer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/element_moseq/moseq_infer.py b/element_moseq/moseq_infer.py index cfbb4649..8a3ddcee 100644 --- a/element_moseq/moseq_infer.py +++ b/element_moseq/moseq_infer.py @@ -176,7 +176,7 @@ class MotionSequence(dj.Part): latent_state (longblob) : Inferred low-dim pose state (x). Low-dimensional representation of the animal's pose in each frame. 
These are similar to PCA scores, are modified to reflect the pose dynamics and noise estimates inferred by the model. centroid (longblob) : Inferred centroid (v). The centroid of the animal in each frame, as estimated by the model. heading (longblob) : Inferred heading (h). The heading of the animal in each frame, as estimated by the model. - motion_sequence_file (attach) : File path of the motion sequence data (CSV format). + motion_sequence_file (attach) : File path of the temporal sequence of motion data (CSV format). """ definition = """ @@ -187,7 +187,7 @@ class MotionSequence(dj.Part): latent_state : longblob # Inferred low-dim pose state (x). Low-dimensional representation of the animal's pose in each frame. These are similar to PCA scores, are modified to reflect the pose dynamics and noise estimates inferred by the model centroid : longblob # Inferred centroid (v). The centroid of the animal in each frame, as estimated by the model heading : longblob # Inferred heading (h). The heading of the animal in each frame, as estimated by the model - motion_sequence_file: attach # File path of the temporal sequence of motion sequence data (CSV format) + motion_sequence_file: attach # File path of the temporal sequence of motion data (CSV format) """ class GridMoviesSampledInstances(dj.Part): From 8864be256ecb6541b3cf22bc104ef88ab6b4884f Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 16 Sep 2025 17:13:00 +0100 Subject: [PATCH 56/61] update `images` --- images/pipeline.svg | 574 ++++++++++----------- images/pipeline_moseq_infer.svg | 56 +- images/pipeline_moseq_infer_and_report.svg | 162 ------ images/pipeline_moseq_train.svg | 266 +++++----- images/pipeline_moseq_train_and_report.svg | 275 ---------- 5 files changed, 448 insertions(+), 885 deletions(-) delete mode 100644 images/pipeline_moseq_infer_and_report.svg delete mode 100644 images/pipeline_moseq_train_and_report.svg diff --git a/images/pipeline.svg b/images/pipeline.svg index 61ade29c..e1096ef0 
100644 --- a/images/pipeline.svg +++ b/images/pipeline.svg @@ -1,425 +1,425 @@ - + - - + + -moseq_train.KeypointSet - - -moseq_train.KeypointSet +moseq_train.LatentDimension + + +moseq_train.LatentDimension - - -moseq_train.KeypointSet.VideoFile - - -moseq_train.KeypointSet.VideoFile + + +moseq_report.PCAReport + + +moseq_report.PCAReport - + -moseq_train.KeypointSet->moseq_train.KeypointSet.VideoFile - +moseq_train.LatentDimension->moseq_report.PCAReport + - - -moseq_train.Bodyparts - - -moseq_train.Bodyparts + + +moseq_infer.Inference.GridMoviesSampledInstances + + +moseq_infer.Inference.GridMoviesSampledInstances - - -moseq_train.KeypointSet->moseq_train.Bodyparts - + + +moseq_infer.VideoRecording.File + + +moseq_infer.VideoRecording.File + - - -moseq_train.PCATask - - -moseq_train.PCATask + + + +Device + + +Device - - -moseq_train.PreProcessing - - -moseq_train.PreProcessing + + +moseq_infer.VideoRecording + + +moseq_infer.VideoRecording - - -moseq_train.PCATask->moseq_train.PreProcessing - + + +Device->moseq_infer.VideoRecording + - - -moseq_train.SelectedFullFit - - -moseq_train.SelectedFullFit + + +moseq_train.PreProcessing + + +moseq_train.PreProcessing - - -moseq_infer.Model - - -moseq_infer.Model + + +moseq_report.PreProcessingReport + + +moseq_report.PreProcessingReport - - -moseq_train.SelectedFullFit->moseq_infer.Model - + + +moseq_train.PreProcessing->moseq_report.PreProcessingReport + - - -report.PCAReport - - -report.PCAReport + + +moseq_train.PCAFit + + +moseq_train.PCAFit - - -moseq_infer.VideoRecording.File - - -moseq_infer.VideoRecording.File - - + + +moseq_train.PreProcessing->moseq_train.PCAFit + - - -moseq_train.FullFit - - -moseq_train.FullFit + + +moseq_train.PreProcessing.Video + + +moseq_train.PreProcessing.Video - + -moseq_train.FullFit->moseq_train.SelectedFullFit - +moseq_train.PreProcessing->moseq_train.PreProcessing.Video + - - -report.FullFitReport - - -report.FullFitReport + + +moseq_train.PreFitTask + + +moseq_train.PreFitTask - - 
-moseq_train.FullFit->report.FullFitReport - - - - -moseq_train.FullFitTask - - -moseq_train.FullFitTask + + +moseq_train.PreFit + + +moseq_train.PreFit - - -moseq_train.FullFitTask->moseq_train.FullFit - + + +moseq_train.PreFitTask->moseq_train.PreFit + - + moseq_infer.Inference - - -moseq_infer.Inference + + +moseq_infer.Inference + + +moseq_infer.Inference->moseq_infer.Inference.GridMoviesSampledInstances + + - + moseq_infer.Inference.MotionSequence - - -moseq_infer.Inference.MotionSequence + + +moseq_infer.Inference.MotionSequence moseq_infer.Inference->moseq_infer.Inference.MotionSequence - + - - -report.InferenceReport - - -report.InferenceReport + + +moseq_report.InferenceReport + + +moseq_report.InferenceReport - + -moseq_infer.Inference->report.InferenceReport - +moseq_infer.Inference->moseq_report.InferenceReport + - - -moseq_infer.Inference.GridMoviesSampledInstances - - -moseq_infer.Inference.GridMoviesSampledInstances + + +moseq_report.FullFitReport + + +moseq_report.FullFitReport - + + +moseq_train.SelectedFullFit + + +moseq_train.SelectedFullFit + + + + + +moseq_infer.Model + + +moseq_infer.Model + + + + -moseq_infer.Inference->moseq_infer.Inference.GridMoviesSampledInstances - +moseq_train.SelectedFullFit->moseq_infer.Model + - - -moseq_train.PreFitTask - - -moseq_train.PreFitTask + + +subject.Subject + + +subject.Subject - - -moseq_train.PreFit - - -moseq_train.PreFit + + +session.Session + + +session.Session - + -moseq_train.PreFitTask->moseq_train.PreFit - +subject.Subject->session.Session + + + + +moseq_report.InferenceReport.Trajectory + + +moseq_report.InferenceReport.Trajectory + + - + moseq_infer.InferenceTask - - -moseq_infer.InferenceTask + + +moseq_infer.InferenceTask moseq_infer.Model->moseq_infer.InferenceTask - + - + -moseq_infer.VideoRecording - - -moseq_infer.VideoRecording +moseq_train.KeypointSet.VideoFile + + +moseq_train.KeypointSet.VideoFile - - -moseq_infer.VideoRecording->moseq_infer.VideoRecording.File - + + 
+moseq_train.PoseEstimationMethod + + +moseq_train.PoseEstimationMethod + - - -moseq_infer.VideoRecording->moseq_infer.InferenceTask - - - -moseq_train.Bodyparts->moseq_train.PCATask - + + +moseq_train.PoseEstimationMethod->moseq_infer.InferenceTask + - - -session.Session - - -session.Session + + +moseq_train.KeypointSet + + +moseq_train.KeypointSet - - -session.Session->moseq_infer.VideoRecording - + + +moseq_train.PoseEstimationMethod->moseq_train.KeypointSet + - - -moseq_train.LatentDimension - - -moseq_train.LatentDimension + + +moseq_train.FullFitTask + + +moseq_train.FullFitTask - - -moseq_train.LatentDimension->report.PCAReport - - - - -report.InferenceReport.Trajectory - - -report.InferenceReport.Trajectory + + +moseq_train.FullFit + + +moseq_train.FullFit - - -report.InferenceReport->report.InferenceReport.Trajectory - + + +moseq_train.FullFitTask->moseq_train.FullFit + - - -moseq_train.PCAFit - - -moseq_train.PCAFit + + +moseq_report.InferenceReport->moseq_report.InferenceReport.Trajectory + + + + +moseq_infer.InferenceTask->moseq_infer.Inference + + + + +moseq_report.PreFitReport + + +moseq_report.PreFitReport - + + +moseq_train.PreFit->moseq_report.PreFitReport + + + -moseq_train.PCAFit->moseq_train.FullFitTask - +moseq_train.PCAFit->moseq_train.LatentDimension + moseq_train.PCAFit->moseq_train.PreFitTask - + - + -moseq_train.PCAFit->moseq_train.LatentDimension - - - - -subject.Subject - - -subject.Subject - - +moseq_train.PCAFit->moseq_train.FullFitTask + - + -subject.Subject->session.Session - - - - -moseq_train.PreProcessing.Video - - -moseq_train.PreProcessing.Video - - - - - -report.PreFitReport - - -report.PreFitReport - - +moseq_infer.VideoRecording->moseq_infer.VideoRecording.File + - + -moseq_infer.InferenceTask->moseq_infer.Inference - - - - -Device - - -Device - - +moseq_infer.VideoRecording->moseq_infer.InferenceTask + - + -Device->moseq_infer.VideoRecording - +session.Session->moseq_infer.VideoRecording + - - 
-moseq_train.PoseEstimationMethod - - -moseq_train.PoseEstimationMethod - + + +moseq_train.KeypointSet->moseq_train.KeypointSet.VideoFile + + + +moseq_train.Bodyparts + + +moseq_train.Bodyparts + - - -moseq_train.PoseEstimationMethod->moseq_train.KeypointSet - - + -moseq_train.PoseEstimationMethod->moseq_infer.InferenceTask - +moseq_train.KeypointSet->moseq_train.Bodyparts + - - -report.PreProcessingReport - - -report.PreProcessingReport + + +moseq_train.PCATask + + +moseq_train.PCATask - + -moseq_train.PreProcessing->moseq_train.PCAFit - +moseq_train.Bodyparts->moseq_train.PCATask + - + -moseq_train.PreProcessing->moseq_train.PreProcessing.Video - +moseq_train.PCATask->moseq_train.PreProcessing + - + -moseq_train.PreProcessing->report.PreProcessingReport - +moseq_train.FullFit->moseq_report.FullFitReport + - + -moseq_train.PreFit->report.PreFitReport - +moseq_train.FullFit->moseq_train.SelectedFullFit + \ No newline at end of file diff --git a/images/pipeline_moseq_infer.svg b/images/pipeline_moseq_infer.svg index bc3d7103..8c9050f9 100644 --- a/images/pipeline_moseq_infer.svg +++ b/images/pipeline_moseq_infer.svg @@ -11,9 +11,9 @@ - + moseq_infer.Inference - + moseq_infer.Inference @@ -33,52 +33,52 @@ - - -moseq_infer.Inference->moseq_infer.Inference.GridMoviesSampledInstances - - - - -moseq_infer.Inference.MotionSequence - - -moseq_infer.Inference.MotionSequence - - - - - -moseq_infer.Inference->moseq_infer.Inference.MotionSequence - - - + moseq_infer.VideoRecording.File - + moseq_infer.VideoRecording.File - + moseq_infer.Model - + moseq_infer.Model - + moseq_infer.Model->moseq_infer.InferenceTask - + + +moseq_infer.Inference->moseq_infer.Inference.GridMoviesSampledInstances + + + +moseq_infer.Inference.MotionSequence + + +moseq_infer.Inference.MotionSequence + + + + + +moseq_infer.Inference->moseq_infer.Inference.MotionSequence + + + + moseq_infer.VideoRecording - + moseq_infer.VideoRecording diff --git a/images/pipeline_moseq_infer_and_report.svg 
b/images/pipeline_moseq_infer_and_report.svg deleted file mode 100644 index f4b92371..00000000 --- a/images/pipeline_moseq_infer_and_report.svg +++ /dev/null @@ -1,162 +0,0 @@ - - - - - -moseq_infer.InferenceTask - - -moseq_infer.InferenceTask - - - - - -moseq_infer.Inference - - -moseq_infer.Inference - - - - - -moseq_infer.InferenceTask->moseq_infer.Inference - - - - -moseq_infer.Inference.GridMoviesSampledInstances - - -moseq_infer.Inference.GridMoviesSampledInstances - - - - - -report.InferenceReport.Trajectory - - -report.InferenceReport.Trajectory - - - - - -report.InferenceReport - - -report.InferenceReport - - - - - -report.InferenceReport->report.InferenceReport.Trajectory - - - - -moseq_infer.Inference->moseq_infer.Inference.GridMoviesSampledInstances - - - - -moseq_infer.Inference->report.InferenceReport - - - - -moseq_infer.Inference.MotionSequence - - -moseq_infer.Inference.MotionSequence - - - - - -moseq_infer.Inference->moseq_infer.Inference.MotionSequence - - - - -report.PreProcessingReport - - -report.PreProcessingReport - - - - - -report.PCAReport - - -report.PCAReport - - - - - -moseq_infer.VideoRecording.File - - -moseq_infer.VideoRecording.File - - - - - -report.PreFitReport - - -report.PreFitReport - - - - - -report.FullFitReport - - -report.FullFitReport - - - - - -moseq_infer.Model - - -moseq_infer.Model - - - - - -moseq_infer.Model->moseq_infer.InferenceTask - - - - -moseq_infer.VideoRecording - - -moseq_infer.VideoRecording - - - - - -moseq_infer.VideoRecording->moseq_infer.InferenceTask - - - - -moseq_infer.VideoRecording->moseq_infer.VideoRecording.File - - - - \ No newline at end of file diff --git a/images/pipeline_moseq_train.svg b/images/pipeline_moseq_train.svg index 14e9a4e2..eeb20963 100644 --- a/images/pipeline_moseq_train.svg +++ b/images/pipeline_moseq_train.svg @@ -1,196 +1,196 @@ - + - - + + -moseq_train.Bodyparts - - -moseq_train.Bodyparts +moseq_train.LatentDimension + + +moseq_train.LatentDimension - - -moseq_train.PCATask 
- - -moseq_train.PCATask + + +moseq_train.PreFit + + +moseq_train.PreFit - - -moseq_train.Bodyparts->moseq_train.PCATask - - - - -moseq_train.KeypointSet - - -moseq_train.KeypointSet + + +moseq_train.PreProcessing + + +moseq_train.PreProcessing - - -moseq_train.KeypointSet->moseq_train.Bodyparts - - - - -moseq_train.KeypointSet.VideoFile - - -moseq_train.KeypointSet.VideoFile + + +moseq_train.PCAFit + + +moseq_train.PCAFit - - -moseq_train.KeypointSet->moseq_train.KeypointSet.VideoFile - + + +moseq_train.PreProcessing->moseq_train.PCAFit + - - -moseq_train.FullFit - - -moseq_train.FullFit + + +moseq_train.PreProcessing.Video + + +moseq_train.PreProcessing.Video - - -moseq_train.SelectedFullFit - - -moseq_train.SelectedFullFit + + +moseq_train.PreProcessing->moseq_train.PreProcessing.Video + + + + +moseq_train.PreFitTask + + +moseq_train.PreFitTask - + + +moseq_train.PreFitTask->moseq_train.PreFit + + + -moseq_train.FullFit->moseq_train.SelectedFullFit - +moseq_train.PCAFit->moseq_train.LatentDimension + + + + +moseq_train.PCAFit->moseq_train.PreFitTask + - + moseq_train.FullFitTask - - -moseq_train.FullFitTask + + +moseq_train.FullFitTask - - -moseq_train.FullFitTask->moseq_train.FullFit - + + +moseq_train.PCAFit->moseq_train.FullFitTask + - - -moseq_train.LatentDimension - - -moseq_train.LatentDimension + + +moseq_train.Bodyparts + + +moseq_train.Bodyparts - - -moseq_train.PreProcessing - - -moseq_train.PreProcessing + + +moseq_train.PCATask + + +moseq_train.PCATask - - -moseq_train.PCATask->moseq_train.PreProcessing - + + +moseq_train.Bodyparts->moseq_train.PCATask + - - -moseq_train.PCAFit - - -moseq_train.PCAFit + + +moseq_train.KeypointSet.VideoFile + + +moseq_train.KeypointSet.VideoFile - - -moseq_train.PCAFit->moseq_train.FullFitTask - - - + -moseq_train.PCAFit->moseq_train.LatentDimension - +moseq_train.PCATask->moseq_train.PreProcessing + - - -moseq_train.PreFitTask - - -moseq_train.PreFitTask + + +moseq_train.FullFit + + +moseq_train.FullFit - + + 
+moseq_train.SelectedFullFit + + +moseq_train.SelectedFullFit + + + + -moseq_train.PCAFit->moseq_train.PreFitTask - +moseq_train.FullFit->moseq_train.SelectedFullFit + - + moseq_train.PoseEstimationMethod - - -moseq_train.PoseEstimationMethod + + +moseq_train.PoseEstimationMethod + + + + + +moseq_train.KeypointSet + + +moseq_train.KeypointSet moseq_train.PoseEstimationMethod->moseq_train.KeypointSet - - - - -moseq_train.PreFit - - -moseq_train.PreFit - + - - + -moseq_train.PreFitTask->moseq_train.PreFit - - - - -moseq_train.PreProcessing.Video - - -moseq_train.PreProcessing.Video - - +moseq_train.KeypointSet->moseq_train.Bodyparts + - + -moseq_train.PreProcessing->moseq_train.PCAFit - +moseq_train.KeypointSet->moseq_train.KeypointSet.VideoFile + - + -moseq_train.PreProcessing->moseq_train.PreProcessing.Video - +moseq_train.FullFitTask->moseq_train.FullFit + \ No newline at end of file diff --git a/images/pipeline_moseq_train_and_report.svg b/images/pipeline_moseq_train_and_report.svg deleted file mode 100644 index b79ed45d..00000000 --- a/images/pipeline_moseq_train_and_report.svg +++ /dev/null @@ -1,275 +0,0 @@ - - - - - -moseq_train.KeypointSet - - -moseq_train.KeypointSet - - - - - -moseq_train.KeypointSet.VideoFile - - -moseq_train.KeypointSet.VideoFile - - - - - -moseq_train.KeypointSet->moseq_train.KeypointSet.VideoFile - - - - -moseq_train.Bodyparts - - -moseq_train.Bodyparts - - - - - -moseq_train.KeypointSet->moseq_train.Bodyparts - - - - -moseq_train.PCATask - - -moseq_train.PCATask - - - - - -moseq_train.PreProcessing - - -moseq_train.PreProcessing - - - - - -moseq_train.PCATask->moseq_train.PreProcessing - - - - -moseq_train.SelectedFullFit - - -moseq_train.SelectedFullFit - - - - - -report.PCAReport - - -report.PCAReport - - - - - -moseq_train.FullFit - - -moseq_train.FullFit - - - - - -moseq_train.FullFit->moseq_train.SelectedFullFit - - - - -report.FullFitReport - - -report.FullFitReport - - - - - -moseq_train.FullFit->report.FullFitReport - - - - 
-moseq_train.FullFitTask - - -moseq_train.FullFitTask - - - - - -moseq_train.FullFitTask->moseq_train.FullFit - - - - -moseq_train.PreFitTask - - -moseq_train.PreFitTask - - - - - -moseq_train.PreFit - - -moseq_train.PreFit - - - - - -moseq_train.PreFitTask->moseq_train.PreFit - - - - -moseq_train.Bodyparts->moseq_train.PCATask - - - - -moseq_train.LatentDimension - - -moseq_train.LatentDimension - - - - - -moseq_train.LatentDimension->report.PCAReport - - - - -report.InferenceReport.Trajectory - - -report.InferenceReport.Trajectory - - - - - -report.InferenceReport - - -report.InferenceReport - - - - - -report.InferenceReport->report.InferenceReport.Trajectory - - - - -moseq_train.PCAFit - - -moseq_train.PCAFit - - - - - -moseq_train.PCAFit->moseq_train.FullFitTask - - - - -moseq_train.PCAFit->moseq_train.PreFitTask - - - - -moseq_train.PCAFit->moseq_train.LatentDimension - - - - -moseq_train.PreProcessing.Video - - -moseq_train.PreProcessing.Video - - - - - -report.PreFitReport - - -report.PreFitReport - - - - - -moseq_train.PoseEstimationMethod - - -moseq_train.PoseEstimationMethod - - - - - -moseq_train.PoseEstimationMethod->moseq_train.KeypointSet - - - - -report.PreProcessingReport - - -report.PreProcessingReport - - - - - -moseq_train.PreProcessing->moseq_train.PCAFit - - - - -moseq_train.PreProcessing->moseq_train.PreProcessing.Video - - - - -moseq_train.PreProcessing->report.PreProcessingReport - - - - -moseq_train.PreFit->report.PreFitReport - - - - \ No newline at end of file From 4cb957cc2d2576fb9553b0b8ef372d90de58ee8d Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 16 Sep 2025 17:25:25 +0100 Subject: [PATCH 57/61] minor fix --- element_moseq/moseq_train.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index 84d5261f..a62e0c11 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -356,11 +356,11 @@ def make_compute( from 
.readers.kpms_reader import _base_config_path cfg_path = _base_config_path(kpset_dir) - if not cfg_path.exists(): + cfg = Path(cfg_path) + if not cfg.exists(): raise FileNotFoundError( f"No DLC config.(yml|yaml) found in {kpset_dir}" ) - cfg = Path(cfg_path) setup_project( project_dir=kpms_project_output_dir.as_posix(), deeplabcut_config=cfg.as_posix(), From 547e12497a15164a24ae987ad2641714051b37ca Mon Sep 17 00:00:00 2001 From: MilagrosMarin Date: Tue, 16 Sep 2025 19:53:44 +0100 Subject: [PATCH 58/61] fix(inference): add `overwrite=True` in `apply_model` to better handle re-running a session --- element_moseq/moseq_infer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/element_moseq/moseq_infer.py b/element_moseq/moseq_infer.py index 8a3ddcee..0919dc45 100644 --- a/element_moseq/moseq_infer.py +++ b/element_moseq/moseq_infer.py @@ -347,6 +347,7 @@ def make_compute( results_path=(inference_output_dir / "results.h5").as_posix(), return_model=False, num_iters=num_iterations or DEFAULT_NUM_ITERS, + overwrite=True, **kpms_dj_config, ) end_time = datetime.now(timezone.utc) From 468b8179e6f597430c250d536b1c1ce0877bb092 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Tue, 16 Sep 2025 21:16:36 +0100 Subject: [PATCH 59/61] Update element_moseq/moseq_train.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- element_moseq/moseq_train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/element_moseq/moseq_train.py b/element_moseq/moseq_train.py index a62e0c11..192a7e66 100644 --- a/element_moseq/moseq_train.py +++ b/element_moseq/moseq_train.py @@ -230,7 +230,7 @@ class PreProcessing(dj.Computed): -> PCATask # Unique ID for each `PCATask` key --- coordinates : longblob # Dictionary mapping filenames to keypoint coordinates as ndarrays of shape (n_frames, n_bodyparts, 2[or 3]) - confidences : longblob # Dictionary mapping filenames to `likelihood` scores as ndarrays of shape (n_frames, n_bodyparts) + 
confidences : longblob # Dictionary mapping filenames to likelihood scores as ndarrays of shape (n_frames, n_bodyparts) formatted_bodyparts : longblob # List of bodypart names. The order of the names matches the order of the bodyparts in `coordinates` and `confidences`. average_frame_rate : float # Average frame rate of the videos for model training (used for kappa calculation). """ From 01e9fe298ff6095e0353509773be5befd2fc392c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Tue, 16 Sep 2025 21:20:20 +0100 Subject: [PATCH 60/61] Update element_moseq/moseq_report.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- element_moseq/moseq_report.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/element_moseq/moseq_report.py b/element_moseq/moseq_report.py index e6f520e1..9d20632b 100644 --- a/element_moseq/moseq_report.py +++ b/element_moseq/moseq_report.py @@ -206,7 +206,7 @@ def make(self, key): moseq_train.get_kpms_processed_data_dir(), prefit_model_name ) prefit_output_dir = Path(prefit_model_dir) / "fitting_progress.pdf" - if prefit_model_dir.exists(): + if prefit_output_dir.exists(): self.insert1({**key, "fitting_progress_pdf": prefit_output_dir}) else: raise FileNotFoundError( From c2f7bbcbe54c74dbadd413ad04b96dff16606388 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Milagros=20Mar=C3=ADn?= Date: Tue, 16 Sep 2025 21:20:30 +0100 Subject: [PATCH 61/61] Update element_moseq/moseq_report.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- element_moseq/moseq_report.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/element_moseq/moseq_report.py b/element_moseq/moseq_report.py index 9d20632b..4b346b6c 100644 --- a/element_moseq/moseq_report.py +++ b/element_moseq/moseq_report.py @@ -228,7 +228,7 @@ def make(self, key): moseq_train.get_kpms_processed_data_dir(), fullfit_model_name ) fullfit_output_file = Path(fullfit_model_dir) / "fitting_progress.pdf" - if 
fullfit_model_dir.exists(): + if fullfit_output_file.exists(): self.insert1({**key, "fitting_progress_pdf": fullfit_output_file}) else: raise FileNotFoundError(