|
19 | 19 | from .readers import dlc_reader |
20 | 20 |
|
21 | 21 | schema = dj.schema() |
| 22 | +logger = dj.logger |
| 23 | + |
22 | 24 | _linking_module = None |
23 | 25 |
|
24 | 26 |
|
@@ -733,10 +735,16 @@ def make(self, key): |
733 | 735 | find_full_path(get_dlc_root_data_dir(), fp).as_posix() |
734 | 736 | for fp in video_relpaths |
735 | 737 | ] |
736 | | - analyze_video_params = (PoseEstimationTask & key).fetch1( |
| 738 | + pose_estimation_params = (PoseEstimationTask & key).fetch1( |
737 | 739 | "pose_estimation_params" |
738 | 740 | ) or {} |
739 | 741 |
|
| 742 | + # expect a nested dictionary with "analyze_videos" params |
| 743 | + # if not, assume "pose_estimation_params" as a flat dictionary that include relevant "analyze_videos" params |
| 744 | + analyze_video_params = ( |
| 745 | + pose_estimation_params.get("analyze_videos") or pose_estimation_params |
| 746 | + ) |
| 747 | + |
740 | 748 | @memoized_result( |
741 | 749 | uniqueness_dict={ |
742 | 750 | **analyze_video_params, |
@@ -867,6 +875,102 @@ def get_trajectory(cls, key: dict, body_parts: list = "all") -> pd.DataFrame: |
867 | 875 | return df |
868 | 876 |
|
869 | 877 |
|
| 878 | +@schema |
| 879 | +class LabeledVideo(dj.Computed): |
| 880 | + definition = """ |
| 881 | + -> PoseEstimation |
| 882 | + """ |
| 883 | + |
| 884 | + class File(dj.Part): |
| 885 | + definition = """ |
| 886 | + -> master |
| 887 | + -> VideoRecording.File |
| 888 | + --- |
| 889 | + labeled_video_path: varchar(255) # relative path to labeled video |
| 890 | + """ |
| 891 | + |
| 892 | + @property |
| 893 | + def key_source(self): |
| 894 | + return PoseEstimation & RecordingInfo |
| 895 | + |
| 896 | + def make(self, key): |
| 897 | + import deeplabcut |
| 898 | + |
| 899 | + pose_estimation_params = (PoseEstimationTask & key).fetch1( |
| 900 | + "pose_estimation_params" |
| 901 | + ) or {} |
| 902 | + |
| 903 | + # expect a nested dictionary with "create_labeled_video" and "extract_outlier_frames" params |
| 904 | + # if not, assume "pose_estimation_params" as a flat dictionary |
| 905 | + create_labeled_video_params = ( |
| 906 | + pose_estimation_params.get("create_labeled_video") or pose_estimation_params |
| 907 | + ) |
| 908 | + |
| 909 | + outputframerate = create_labeled_video_params.pop( |
| 910 | + "outputframerate", 5 |
| 911 | + ) # final labeled video FPS defaults to 5 Hz |
| 912 | + |
| 913 | + dlc_model_ = (Model & key).fetch1() |
| 914 | + fps, nframes = (RecordingInfo & key).fetch1("fps", "nframes") |
| 915 | + output_dir = (PoseEstimationTask & key).fetch1("pose_estimation_output_dir") |
| 916 | + output_dir = find_full_path(get_dlc_root_data_dir(), output_dir) |
| 917 | + |
| 918 | + project_path = find_full_path( |
| 919 | + get_dlc_root_data_dir(), dlc_model_["project_path"] |
| 920 | + ) |
| 921 | + |
| 922 | + try: |
| 923 | + dlc_config = next(output_dir.glob("dj_dlc_config*.yaml")) |
| 924 | + dlc_config = project_path / dlc_config.name |
| 925 | + assert dlc_config.exists() |
| 926 | + except (StopIteration, AssertionError): |
| 927 | + dlc_config = next(project_path.glob("dj_dlc_config*.yaml")) |
| 928 | + logger.warning( |
| 929 | + f"No dj_dlc_config*.yaml file found in {output_dir} - this is unexpected.\nUsing {dlc_config}" |
| 930 | + ) |
| 931 | + |
| 932 | + entries = [] |
| 933 | + for vkey in (VideoRecording.File & key).fetch("KEY"): |
| 934 | + video_file = (VideoRecording.File & vkey).fetch1("file_path") |
| 935 | + video_file = find_full_path(get_dlc_root_data_dir(), video_file) |
| 936 | + |
| 937 | + # -- create labeled video -- |
| 938 | + create_labeled_video_kwargs = { |
| 939 | + k: v |
| 940 | + for k, v in create_labeled_video_params.items() |
| 941 | + if k in inspect.signature(deeplabcut.create_labeled_video).parameters |
| 942 | + } |
| 943 | + create_labeled_video_kwargs.update( |
| 944 | + dict( |
| 945 | + config=dlc_config.as_posix(), |
| 946 | + videos=[video_file.as_posix()], |
| 947 | + shuffle=dlc_model_["shuffle"], |
| 948 | + trainingsetindex=dlc_model_["trainingsetindex"], |
| 949 | + modelprefix=dlc_model_["model_prefix"], |
| 950 | + destfolder=output_dir.as_posix(), |
| 951 | + Frames2plot=np.arange(0, nframes, int(fps / outputframerate)), |
| 952 | + outputframerate=outputframerate, |
| 953 | + ) |
| 954 | + ) |
| 955 | + deeplabcut.create_labeled_video(**create_labeled_video_kwargs) |
| 956 | + |
| 957 | + labeled_video_path = next( |
| 958 | + output_dir.glob(f"{video_file.stem}*_labeled.mp4") |
| 959 | + ) |
| 960 | + entries.append( |
| 961 | + { |
| 962 | + **key, |
| 963 | + **vkey, |
| 964 | + "labeled_video_path": labeled_video_path.relative_to( |
| 965 | + get_dlc_processed_data_dir() |
| 966 | + ).as_posix(), |
| 967 | + } |
| 968 | + ) |
| 969 | + |
| 970 | + self.insert1(key) |
| 971 | + self.File.insert(entries) |
| 972 | + |
| 973 | + |
870 | 974 | def str_to_bool(value) -> bool: |
871 | 975 | """Return whether the provided string represents true. Otherwise false. |
872 | 976 |
|
|
0 commit comments