1111
1212from .base import Manager
1313from .model_manager import ModelManager
14-
14+ from pathlib import Path
15+ import os
1516
1617def get_orientation_vector (cls , b1_name , b2_name ):
1718 b1 = cls .get_keypoints ()[:, :, cls .get_bodypart_index (b1_name ), :]
@@ -79,12 +80,18 @@ def __init__(self, config: Dict[str, str], model_manager: ModelManager):
7980 self .model_manager = model_manager
8081 self .animals : List [AnimalSeq ] = []
8182 self .full_keypoint_names = []
82- keypoint_info = config ["keypoint_info" ]
83+
84+ self .init_pose ()
85+
86+ def init_pose (self ):
87+ keypoint_info = self .config ["keypoint_info" ]
88+
8389 if keypoint_info ["keypoint_file_path" ] is None :
8490 # no need to initialize here
8591 return
8692 else :
87- self .keypoint_file_path = config ["keypoint_info" ]["keypoint_file_path" ]
93+ self .keypoint_file_path = self .config ["keypoint_info" ]["keypoint_file_path" ]
94+
8895 if self .keypoint_file_path .endswith (".h5" ):
8996 all_keypoints = self ._process_keypoint_file_from_h5 ()
9097 elif self .keypoint_file_path .endswith (".json" ):
@@ -97,17 +104,17 @@ def __init__(self, config: Dict[str, str], model_manager: ModelManager):
97104 animalseq = AnimalSeq (
98105 animal_name , all_keypoints [:, individual_id ], self .keypoint_names
99106 )
100- if "body_orientation_keypoints" in config ["keypoint_info" ]:
107+ if "body_orientation_keypoints" in self . config ["keypoint_info" ]:
101108 animalseq .set_body_orientation_keypoints (
102- config ["keypoint_info" ]["body_orientation_keypoints" ]
109+ self . config ["keypoint_info" ]["body_orientation_keypoints" ]
103110 )
104- if "head_orientation_keypoints" in config ["keypoint_info" ]:
111+ if "head_orientation_keypoints" in self . config ["keypoint_info" ]:
105112 animalseq .set_head_orientation_keypoints (
106- config ["keypoint_info" ]["head_orientation_keypoints" ]
113+ self . config ["keypoint_info" ]["head_orientation_keypoints" ]
107114 )
108115
109- self .animals .append (animalseq )
110-
116+ self .animals .append (animalseq )
117+
111118 def _process_keypoint_file_from_h5 (self ) -> ndarray :
112119 df = pd .read_hdf (self .keypoint_file_path )
113120 self .full_keypoint_names = list (
@@ -212,15 +219,25 @@ def get_keypoints(self) -> ndarray:
212219 """
213220 Get the keypoints of animals. The shape is of shape n_frames, n_individuals, n_kpts, n_dims
214221 """
215- use_superanimal = False
216- if use_superanimal :
222+
223+ keypoint_file_path = self .config ['keypoint_info' ]['keypoint_file_path' ]
224+ video_file_path = self .config ['video_info' ]['video_file_path' ]
225+ if os .path .exists (video_file_path ) and keypoint_file_path is None :
226+
217227 import deeplabcut
218228 from deeplabcut .modelzoo .video_inference import video_inference_superanimal
219229 superanimal_name = 'superanimal_topviewmouse_hrnetw32'
220- video_inference_superanimal (videos = [self .config ['video_info' ]['video_file_path' ]],
221- superanimal_name = superanimal_name ,
222- video_adapt = False ,
223- dest_folder = 'temp_pose' )
230+
231+ keypoint_file_path = video_file_path .replace ('.mp4' , '_' + superanimal_name + '.h5' )
232+ if not os .path .exists (keypoint_file_path ):
233+ video_inference_superanimal (videos = [self .config ['video_info' ]['video_file_path' ]],
234+ superanimal_name = superanimal_name ,
235+ video_adapt = False )
236+
237+ if os .path .exists (keypoint_file_path ):
238+
239+ self .config ['keypoint_info' ]['keypoint_file_path' ] = keypoint_file_path
240+ self .init_pose ()
224241
225242 ret = np .stack ([animal .get_keypoints () for animal in self .animals ], axis = 1 )
226243 return ret
0 commit comments