Skip to content
Merged
Changes from 1 commit
Commits
Show all changes
63 commits
Select commit Hold shift + click to select a range
778d763
feat: several fixes to config yml overwrite and add feat
MilagrosMarin Aug 26, 2025
558b6f5
Update CHANGELOG and bump version
MilagrosMarin Aug 26, 2025
442a718
update CHANGELOG
MilagrosMarin Aug 26, 2025
fc68fa0
refactor:`Inference` paths
MilagrosMarin Sep 4, 2025
e10b771
feat(VideoFile): add new attributes for downstream statistics
MilagrosMarin Sep 4, 2025
2ded339
feat(PCATask): add new attribute and update docstrings
MilagrosMarin Sep 4, 2025
cbe4944
docs(PreProcessing): update docstrings
MilagrosMarin Sep 4, 2025
aa28e67
feat(PreProcessing): add computation of FPS and video duration
MilagrosMarin Sep 4, 2025
3982816
docs(PCAFit): update docstring
MilagrosMarin Sep 4, 2025
57d5117
docs(LatentDimension): udpate docstring
MilagrosMarin Sep 4, 2025
3e0c60a
docs(PreFitTask, PreFit): update docstrings
MilagrosMarin Sep 4, 2025
5bbdd4f
feat(PreFit): add `estimate_sigmasq_loc` from the latest KPMS version
MilagrosMarin Sep 4, 2025
185e78d
fix(PreFit): change folder name to match primary attributes instead o…
MilagrosMarin Sep 4, 2025
5a0c254
docs(FullFitTask): update docstring
MilagrosMarin Sep 4, 2025
742249d
docs(FullFit): udpate docstrings
MilagrosMarin Sep 4, 2025
76015c3
feat(FullFit): add `estimate_sigmasq_loc` from the latest version of …
MilagrosMarin Sep 4, 2025
b343e1c
fix(FullFit): update folder name to primary attributes instead of dat…
MilagrosMarin Sep 4, 2025
b2d2013
docs(SelectedFullFit): update docstring
MilagrosMarin Sep 4, 2025
bbd3a01
fix(devcontainer): update Dockerfile, update python version for KPMS …
MilagrosMarin Sep 4, 2025
9bcf9a3
feat(pyproject): update version of KPMS from 0.4.8 to the latest vers…
MilagrosMarin Sep 4, 2025
e8e51f1
update CHANGELOG
MilagrosMarin Sep 4, 2025
317c51a
update CHANGELOG
MilagrosMarin Sep 4, 2025
911546f
fix(pyproject): update version
MilagrosMarin Sep 4, 2025
3f782b7
Minor update to attribute defaults
MilagrosMarin Sep 4, 2025
0684a29
fix model_name_str
MilagrosMarin Sep 4, 2025
0650b36
review: update filename to `conda_env.yml`
MilagrosMarin Sep 9, 2025
cc50492
feat: add `report` schema
MilagrosMarin Sep 9, 2025
6961f1e
review: apply suggestions to `moseq_train`
MilagrosMarin Sep 9, 2025
fcd29ee
refactor(moseq_infer)
MilagrosMarin Sep 9, 2025
5d8f036
feat(plotting): move and refactor `viz_utils`
MilagrosMarin Sep 9, 2025
0259b3a
refactor(kpms_reader)
MilagrosMarin Sep 9, 2025
9e8bb3e
update(tutorial_pipeline)
MilagrosMarin Sep 9, 2025
25b427f
minor fix in viz_utils
MilagrosMarin Sep 9, 2025
0a46a63
update images
MilagrosMarin Sep 9, 2025
cca9f4d
update CHANGELOG
MilagrosMarin Sep 9, 2025
422c2bd
updated `model_name` varchar from 100 to 1000
MilagrosMarin Sep 9, 2025
5b5d305
update(report): new funciton name for `load_kpms_dj_config`
MilagrosMarin Sep 9, 2025
cfe7de7
refactor(PreProcessing): remove redundancy of variables
MilagrosMarin Sep 10, 2025
a067ce9
update docstrings
MilagrosMarin Sep 10, 2025
8d1ff5d
docs(moseq_train)
MilagrosMarin Sep 10, 2025
9e37ca9
refactor(moseq_infer): apply 3-part make function and update docstrings
MilagrosMarin Sep 10, 2025
ebba3e7
Update element_moseq/readers/kpms_reader.py
MilagrosMarin Sep 10, 2025
7dfc208
udpate CHANGELOG and bump version to 1.0.0 instead and according to a…
MilagrosMarin Sep 10, 2025
b34404b
minor refactor in PreProcessing
MilagrosMarin Sep 10, 2025
e5f279b
update docstrings in kpms_reader
MilagrosMarin Sep 10, 2025
d7f503b
update CHANGELOG
MilagrosMarin Sep 10, 2025
2603729
Update element_moseq/report.py
MilagrosMarin Sep 15, 2025
ec6ad58
from `report` to `moseq_report` and refactor path exists
MilagrosMarin Sep 16, 2025
70210fa
reafactor `inference`
MilagrosMarin Sep 16, 2025
e837f34
update CHANGELOG
MilagrosMarin Sep 16, 2025
4dc17e4
Merge remote-tracking branch 'refs/remotes/origin/feat_outliers_remov…
MilagrosMarin Sep 16, 2025
684ebe6
update tutorial_pipeline
MilagrosMarin Sep 16, 2025
2954d51
black formatting in `tutorial_pipeline`
MilagrosMarin Sep 16, 2025
ddf2eea
refactor(inference): from `insert` to `insert1` in a loop to prevent …
MilagrosMarin Sep 16, 2025
971fd38
refactor(inference): rename attributes `file_h5` and `file_csv`
MilagrosMarin Sep 16, 2025
5c5b8e3
update docstring in inference
MilagrosMarin Sep 16, 2025
8864be2
update `images`
MilagrosMarin Sep 16, 2025
4cb957c
minor fix
MilagrosMarin Sep 16, 2025
547e124
fix(inference): add `overwrite=True` in `apply_model` to better handl…
MilagrosMarin Sep 16, 2025
468b817
Update element_moseq/moseq_train.py
MilagrosMarin Sep 16, 2025
01e9fe2
Update element_moseq/moseq_report.py
MilagrosMarin Sep 16, 2025
c2f7bbc
Update element_moseq/moseq_report.py
MilagrosMarin Sep 16, 2025
59d5728
Merge branch 'datajoint:main' into feat_outliers_removal
MilagrosMarin Sep 16, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 64 additions & 58 deletions element_moseq/moseq_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,17 +267,28 @@ def make_compute(
"""
from keypoint_moseq import (
apply_model,
filter_centroids_headings,
format_data,
generate_grid_movies,
generate_trajectory_plots,
get_syllable_instances,
load_checkpoint,
load_keypoints,
load_pca,
load_results,
plot_similarity_dendrogram,
plot_syllable_frequencies,
sample_instances,
save_results_as_csv,
)

# Constants used by default as in kpms
DEFAULT_NUM_ITERS = 50
FILTER_SIZE = 9
MIN_DURATION = 3
MIN_FREQUENCY = 0.005
GRID_SAMPLES = 4 * 6 # minimum rows * cols

kpms_root = moseq_train.get_kpms_root_data_dir()
kpms_processed = moseq_train.get_kpms_processed_data_dir()

Expand Down Expand Up @@ -314,11 +325,6 @@ def make_compute(
coordinates, confidences, _ = load_keypoints(
filepath_pattern=keypointset_dir, format=pose_estimation_method
)
else:
raise NotImplementedError(
"The currently supported format method is `deeplabcut`. If you require \
support for another format method, please reach out to us at `support@datajoint.com`."
)

kpms_dj_config = kpms_reader.load_kpms_dj_config(model_dir.parent)

Expand All @@ -340,8 +346,7 @@ def make_compute(
model_name=Path(model_dir).name,
results_path=(inference_output_dir / "results.h5").as_posix(),
return_model=False,
num_iters=num_iterations
or 50, # default internal value in the keypoint-moseq function
num_iters=num_iterations or DEFAULT_NUM_ITERS,
**kpms_dj_config,
)
end_time = datetime.now(timezone.utc)
Expand Down Expand Up @@ -381,59 +386,76 @@ def make_compute(
)

else:
from keypoint_moseq import (
filter_centroids_headings,
get_syllable_instances,
load_results,
sample_instances,
)

# load results
results = load_results(
project_dir=inference_output_dir.parent,
model_name=inference_output_dir.parts[-1],
)

# extract syllables from results
syllables = {k: v["syllable"] for k, v in results.items()}
# extract syllables from results
syllables = {k: v["syllable"] for k, v in results.items()}

# extract and smooth centroids and headings
centroids = {k: v["centroid"] for k, v in results.items()}
headings = {k: v["heading"] for k, v in results.items()}
# extract and smooth centroids and headings
centroids = {k: v["centroid"] for k, v in results.items()}
headings = {k: v["heading"] for k, v in results.items()}

filter_size = 9 # default value
centroids, headings = filter_centroids_headings(
centroids, headings, filter_size=filter_size
)
centroids, headings = filter_centroids_headings(
centroids, headings, filter_size=FILTER_SIZE
)

# extract sample instances for each syllable
syllable_instances = get_syllable_instances(
syllables, min_duration=3, min_frequency=0.005
)
# Map each syllable to a list of its sampled events.
sampled_instances = sample_instances(
syllable_instances=syllable_instances,
num_samples=4 * 6, # minimum rows * cols
coordinates=coordinates,
centroids=centroids,
headings=headings,
# extract sample instances for each syllable
syllable_instances = get_syllable_instances(
syllables, min_duration=MIN_DURATION, min_frequency=MIN_FREQUENCY
)
# Map each syllable to a list of its sampled events.
sampled_instances = sample_instances(
syllable_instances=syllable_instances,
num_samples=GRID_SAMPLES,
coordinates=coordinates,
centroids=centroids,
headings=headings,
)

duration_seconds = None

# Prepare motion sequence data
motion_sequence_data = []
for result_idx, result in results.items():
motion_sequence_data.append(
{
**key,
"video_name": result_idx,
"syllable": result["syllable"],
"latent_state": result["latent_state"],
"centroid": result["centroid"],
"heading": result["heading"],
"file_csv": (
inference_output_dir / "results_as_csv" / f"{result_idx}.csv"
).as_posix(),
}
)

duration_seconds = None
# Prepare grid movie data
grid_movie_data = []
for syllable, sampled_instance in sampled_instances.items():
grid_movie_data.append(
{**key, "syllable": syllable, "instances": sampled_instance}
)

return (
duration_seconds,
results,
sampled_instances,
motion_sequence_data,
grid_movie_data,
inference_output_dir,
)

def make_insert(
self,
key,
duration_seconds,
results,
sampled_instances,
motion_sequence_data,
grid_movie_data,
inference_output_dir,
):
"""
Expand All @@ -447,24 +469,8 @@ def make_insert(
}
)

# Insert motion sequence results
for result_idx, result in results.items():
self.MotionSequence.insert1(
{
**key,
"video_name": result_idx,
"syllable": result["syllable"],
"latent_state": result["latent_state"],
"centroid": result["centroid"],
"heading": result["heading"],
"file_csv": (
inference_output_dir / "results_as_csv" / f"{result_idx}.csv"
).as_posix(),
}
)
# Add key to each motion sequence record and insert
self.MotionSequence.insert(motion_sequence_data)

# Insert grid movie sampled instances
for syllable, sampled_instance in sampled_instances.items():
self.GridMoviesSampledInstances.insert1(
{**key, "syllable": syllable, "instances": sampled_instance}
)
# Add key to each grid movie record and insert
self.GridMoviesSampledInstances.insert(grid_movie_data)