Skip to content

Commit 0775230

Browse files
committed
Edited installation files. SuperAnimal can be run if there is no keypoint file
1 parent aef5a90 commit 0775230

File tree

5 files changed

+41
-23
lines changed

5 files changed

+41
-23
lines changed

amadeusgpt/managers/animal_manager.py

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,8 @@
1111

1212
from .base import Manager
1313
from .model_manager import ModelManager
14-
14+
from pathlib import Path
15+
import os
1516

1617
def get_orientation_vector(cls, b1_name, b2_name):
1718
b1 = cls.get_keypoints()[:, :, cls.get_bodypart_index(b1_name), :]
@@ -79,12 +80,18 @@ def __init__(self, config: Dict[str, str], model_manager: ModelManager):
7980
self.model_manager = model_manager
8081
self.animals: List[AnimalSeq] = []
8182
self.full_keypoint_names = []
82-
keypoint_info = config["keypoint_info"]
83+
84+
self.init_pose()
85+
86+
def init_pose(self):
87+
keypoint_info = self.config["keypoint_info"]
88+
8389
if keypoint_info["keypoint_file_path"] is None:
8490
# no need to initialize here
8591
return
8692
else:
87-
self.keypoint_file_path = config["keypoint_info"]["keypoint_file_path"]
93+
self.keypoint_file_path = self.config["keypoint_info"]["keypoint_file_path"]
94+
8895
if self.keypoint_file_path.endswith(".h5"):
8996
all_keypoints = self._process_keypoint_file_from_h5()
9097
elif self.keypoint_file_path.endswith(".json"):
@@ -97,17 +104,17 @@ def __init__(self, config: Dict[str, str], model_manager: ModelManager):
97104
animalseq = AnimalSeq(
98105
animal_name, all_keypoints[:, individual_id], self.keypoint_names
99106
)
100-
if "body_orientation_keypoints" in config["keypoint_info"]:
107+
if "body_orientation_keypoints" in self.config["keypoint_info"]:
101108
animalseq.set_body_orientation_keypoints(
102-
config["keypoint_info"]["body_orientation_keypoints"]
109+
self.config["keypoint_info"]["body_orientation_keypoints"]
103110
)
104-
if "head_orientation_keypoints" in config["keypoint_info"]:
111+
if "head_orientation_keypoints" in self.config["keypoint_info"]:
105112
animalseq.set_head_orientation_keypoints(
106-
config["keypoint_info"]["head_orientation_keypoints"]
113+
self.config["keypoint_info"]["head_orientation_keypoints"]
107114
)
108115

109-
self.animals.append(animalseq)
110-
116+
self.animals.append(animalseq)
117+
111118
def _process_keypoint_file_from_h5(self) -> ndarray:
112119
df = pd.read_hdf(self.keypoint_file_path)
113120
self.full_keypoint_names = list(
@@ -212,15 +219,25 @@ def get_keypoints(self) -> ndarray:
212219
"""
213220
Get the keypoints of animals. The shape is of shape n_frames, n_individuals, n_kpts, n_dims
214221
"""
215-
use_superanimal = False
216-
if use_superanimal:
222+
223+
keypoint_file_path = self.config['keypoint_info']['keypoint_file_path']
224+
video_file_path = self.config['video_info']['video_file_path']
225+
if os.path.exists(video_file_path) and keypoint_file_path is None:
226+
217227
import deeplabcut
218228
from deeplabcut.modelzoo.video_inference import video_inference_superanimal
219229
superanimal_name = 'superanimal_topviewmouse_hrnetw32'
220-
video_inference_superanimal(videos = [self.config['video_info']['video_file_path']],
221-
superanimal_name = superanimal_name,
222-
video_adapt = False,
223-
dest_folder = 'temp_pose')
230+
231+
keypoint_file_path = video_file_path.replace('.mp4', '_' + superanimal_name + '.h5')
232+
if not os.path.exists(keypoint_file_path):
233+
video_inference_superanimal(videos = [self.config['video_info']['video_file_path']],
234+
superanimal_name = superanimal_name,
235+
video_adapt = False)
236+
237+
if os.path.exists(keypoint_file_path):
238+
239+
self.config['keypoint_info']['keypoint_file_path'] = keypoint_file_path
240+
self.init_pose()
224241

225242
ret = np.stack([animal.get_keypoints() for animal in self.animals], axis=1)
226243
return ret

install_cpu.sh

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
source /Users/shaokaiye/miniforge3/bin/activate
33
conda env create -f conda/amadesuGPT-cpu.yml
44
conda activate amadeusgpt-cpu
5-
conda install pytorch cpuonly -c pytorch
6-
5+
conda install pytorch torchvision cpuonly -c pytorch
6+
pip install "git+https://github.com/DeepLabCut/DeepLabCut.git@pytorch_dlc#egg=deeplabcut"
7+
pip install pycocotools
78
pip install -e .[streamlit]

install_gpu.sh

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
#!/bin/bash
2-
source /Users/shaokaiye/miniforge3/bin/activate
2+
source /mnt/md0/shaokai/miniconda3/bin/activate
33
conda env create -f conda/amadesuGPT-gpu.yml
44
conda activate amadeusgpt-gpu
5-
conda install pytorch cudatoolkit=11.8 -c pytorch
6-
pip install "git+https://github.com/DeepLabCut/DeepLabCut.git@pytorch_dlc#egg=deeplabcut[gui,modelzoo,wandb]"
7-
8-
5+
# adjust this line according to your cuda version
6+
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
7+
pip install "git+https://github.com/DeepLabCut/DeepLabCut.git@pytorch_dlc#egg=deeplabcut"
8+
pip install pycocotools
99
pip install -e .[streamlit]
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,5 @@
22
source /Users/shaokaiye/miniforge3/bin/activate
33
conda env create -f conda/amadesuGPT-minimal.yml
44
conda activate amadeusgpt-minimal
5-
5+
pip install pycocotools
66
pip install -e .[streamlit]

0 commit comments

Comments
 (0)