Skip to content

Commit 117694e

Browse files
authored
Merge pull request #121 from ttngu207/main
add new table `LabeledVideo` to generate/store labeled video data after PoseEstimation
2 parents 2833f15 + 8fff7b0 commit 117694e

File tree

3 files changed

+115
-2
lines changed

3 files changed

+115
-2
lines changed

CHANGELOG.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,15 @@
33
Observes [Semantic Versioning](https://semver.org/spec/v2.0.0.html) standard and
44
[Keep a Changelog](https://keepachangelog.com/en/1.0.0/) convention.
55

6+
7+
## [0.3.1] - 2024-08-16
8+
9+
+ Add - add new table `LabeledVideo` to generate/store labeled video data after PoseEstimation
10+
11+
## [0.3.0] - 2024-08-08
12+
13+
+ Add - add support for inference (PoseEstimation) using pytorch model
14+
615
## [0.2.14] - 2024-08-02
716

817
+ Fix - improve imports, avoid circular dependencies

element_deeplabcut/model.py

Lines changed: 105 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from .readers import dlc_reader
2020

2121
schema = dj.schema()
22+
logger = dj.logger
23+
2224
_linking_module = None
2325

2426

@@ -733,10 +735,16 @@ def make(self, key):
733735
find_full_path(get_dlc_root_data_dir(), fp).as_posix()
734736
for fp in video_relpaths
735737
]
736-
analyze_video_params = (PoseEstimationTask & key).fetch1(
738+
pose_estimation_params = (PoseEstimationTask & key).fetch1(
737739
"pose_estimation_params"
738740
) or {}
739741

742+
# expect a nested dictionary with "analyze_videos" params
743+
# if not, assume "pose_estimation_params" as a flat dictionary that include relevant "analyze_videos" params
744+
analyze_video_params = (
745+
pose_estimation_params.get("analyze_videos") or pose_estimation_params
746+
)
747+
740748
@memoized_result(
741749
uniqueness_dict={
742750
**analyze_video_params,
@@ -867,6 +875,102 @@ def get_trajectory(cls, key: dict, body_parts: list = "all") -> pd.DataFrame:
867875
return df
868876

869877

878+
@schema
879+
class LabeledVideo(dj.Computed):
880+
definition = """
881+
-> PoseEstimation
882+
"""
883+
884+
class File(dj.Part):
885+
definition = """
886+
-> master
887+
-> VideoRecording.File
888+
---
889+
labeled_video_path: varchar(255) # relative path to labeled video
890+
"""
891+
892+
@property
893+
def key_source(self):
894+
return PoseEstimation & RecordingInfo
895+
896+
def make(self, key):
897+
import deeplabcut
898+
899+
pose_estimation_params = (PoseEstimationTask & key).fetch1(
900+
"pose_estimation_params"
901+
) or {}
902+
903+
# expect a nested dictionary with "create_labeled_video" and "extract_outlier_frames" params
904+
# if not, assume "pose_estimation_params" as a flat dictionary
905+
create_labeled_video_params = (
906+
pose_estimation_params.get("create_labeled_video") or pose_estimation_params
907+
)
908+
909+
outputframerate = create_labeled_video_params.pop(
910+
"outputframerate", 5
911+
) # final labeled video FPS defaults to 5 Hz
912+
913+
dlc_model_ = (Model & key).fetch1()
914+
fps, nframes = (RecordingInfo & key).fetch1("fps", "nframes")
915+
output_dir = (PoseEstimationTask & key).fetch1("pose_estimation_output_dir")
916+
output_dir = find_full_path(get_dlc_root_data_dir(), output_dir)
917+
918+
project_path = find_full_path(
919+
get_dlc_root_data_dir(), dlc_model_["project_path"]
920+
)
921+
922+
try:
923+
dlc_config = next(output_dir.glob("dj_dlc_config*.yaml"))
924+
dlc_config = project_path / dlc_config.name
925+
assert dlc_config.exists()
926+
except (StopIteration, AssertionError):
927+
dlc_config = next(project_path.glob("dj_dlc_config*.yaml"))
928+
logger.warning(
929+
f"No dj_dlc_config*.yaml file found in {output_dir} - this is unexpected.\nUsing {dlc_config}"
930+
)
931+
932+
entries = []
933+
for vkey in (VideoRecording.File & key).fetch("KEY"):
934+
video_file = (VideoRecording.File & vkey).fetch1("file_path")
935+
video_file = find_full_path(get_dlc_root_data_dir(), video_file)
936+
937+
# -- create labeled video --
938+
create_labeled_video_kwargs = {
939+
k: v
940+
for k, v in create_labeled_video_params.items()
941+
if k in inspect.signature(deeplabcut.create_labeled_video).parameters
942+
}
943+
create_labeled_video_kwargs.update(
944+
dict(
945+
config=dlc_config.as_posix(),
946+
videos=[video_file.as_posix()],
947+
shuffle=dlc_model_["shuffle"],
948+
trainingsetindex=dlc_model_["trainingsetindex"],
949+
modelprefix=dlc_model_["model_prefix"],
950+
destfolder=output_dir.as_posix(),
951+
Frames2plot=np.arange(0, nframes, int(fps / outputframerate)),
952+
outputframerate=outputframerate,
953+
)
954+
)
955+
deeplabcut.create_labeled_video(**create_labeled_video_kwargs)
956+
957+
labeled_video_path = next(
958+
output_dir.glob(f"{video_file.stem}*_labeled.mp4")
959+
)
960+
entries.append(
961+
{
962+
**key,
963+
**vkey,
964+
"labeled_video_path": labeled_video_path.relative_to(
965+
get_dlc_processed_data_dir()
966+
).as_posix(),
967+
}
968+
)
969+
970+
self.insert1(key)
971+
self.File.insert(entries)
972+
973+
870974
def str_to_bool(value) -> bool:
871975
"""Return whether the provided string represents true. Otherwise false.
872976

element_deeplabcut/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,4 @@
22
Package metadata
33
"""
44

5-
__version__ = "0.3.0"
5+
__version__ = "0.3.1"

0 commit comments

Comments
 (0)