|
16 | 16 | from scipy.ndimage.measurements import label, maximum_position |
17 | 17 | from scipy.ndimage.morphology import generate_binary_structure, binary_erosion |
18 | 18 | from scipy.ndimage.filters import maximum_filter |
19 | | -from utils.configloader import deeplabcut_config |
| 19 | +from utils.configloader import deeplabcut_config, MODEL_ORIGIN |
20 | 20 |
|
21 | 21 | MODEL = deeplabcut_config['model'] |
22 | 22 | DLC_PATH = deeplabcut_config['dlc_path'] |
|
25 | 25 | try: |
26 | 26 | import deeplabcut.pose_estimation_tensorflow.nnet.predict as predict |
27 | 27 | from deeplabcut.pose_estimation_tensorflow.config import load_config |
| 28 | + from deeplabcut.pose_estimation_tensorflow.nnet import predict_multianimal |
| 29 | + |
28 | 30 | models_folder = 'pose_estimation_tensorflow/models/' |
29 | 31 | # if not DLC 2 is not installed, try import from DLC 1 the old way |
30 | 32 | except ImportError: |
@@ -118,7 +120,7 @@ def find_local_peaks_new(scoremap: np.ndarray, local_reference: np.ndarray, anim |
118 | 120 | return all_peaks |
119 | 121 |
|
120 | 122 |
|
121 | | -def calculate_skeletons(peaks: dict, animals_number: int) -> list: |
| 123 | +def calculate_dlstream_skeletons(peaks: dict, animals_number: int) -> list: |
122 | 124 | """ |
123 | 125 | Creating skeletons from given peaks |
124 | 126 | There could be no more skeletons than animals_number |
@@ -188,6 +190,61 @@ def create_animal_skeleton(dots_cluster: tuple) -> dict: |
188 | 190 |
|
189 | 191 | return animal_skeletons |
190 | 192 |
|
| 193 | +"""maDLC""" |
| 194 | + |
| 195 | +def get_ma_pose(image, config, session, inputs, outputs): |
| 196 | + """ |
| 197 | + Gets scoremap, local reference and pose from DeepLabCut using given image |
| 198 | + Pose is most probable points for each joint, and not really used later |
| 199 | + Scoremap and local reference is essential to extract skeletons |
| 200 | + :param image: frame which would be analyzed |
| 201 | + :param config, session, inputs, outputs: DeepLabCut configuration and TensorFlow variables from load_deeplabcut() |
| 202 | +
|
| 203 | + :return: tuple of scoremap, local reference and pose |
| 204 | + """ |
| 205 | + scmap, locref, paf, pose = predict_multianimal.get_detectionswithcosts(image, config, session, inputs, outputs, outall=True, |
| 206 | + nms_radius=5.0, |
| 207 | + det_min_score=0.1, |
| 208 | + c_engine=False) |
| 209 | + return pose |
| 210 | + |
| 211 | +def calculate_ma_skeletons(pose: dict, animals_number: int) -> list: |
| 212 | + """ |
| 213 | + Creating skeletons from given pose in maDLC |
| 214 | + There could be no more skeletons than animals_number |
| 215 | + Only unique skeletons output |
| 216 | + """ |
| 217 | + |
| 218 | + def extract_to_animal_skeleton(coords): |
| 219 | + """ |
| 220 | + Creating a easy to read skeleton from dots cluster |
| 221 | + Format for each joint: |
| 222 | + {'joint_name': (x,y)} |
| 223 | + """ |
| 224 | + bodyparts = np.array(coords[0]) |
| 225 | + skeletons = {} |
| 226 | + for bp in range(len(bodyparts)): |
| 227 | + for animal_num in range(animals_number): |
| 228 | + if 'Mouse'+str(animal_num+1) not in skeletons.keys(): |
| 229 | + skeletons['Mouse' + str(animal_num + 1)] = {} |
| 230 | + if len(bodyparts[bp]) >= animals_number: |
| 231 | + skeletons['Mouse'+str(animal_num+1)]['bp' + str(bp + 1)] = bodyparts[bp][animal_num].astype(int) |
| 232 | + else: |
| 233 | + if animal_num < len(bodyparts[bp]): |
| 234 | + skeletons['Mouse'+str(animal_num+1)]['bp' + str(bp + 1)] = bodyparts[bp][animal_num].astype(int) |
| 235 | + else: |
| 236 | + skeletons['Mouse'+str(animal_num+1)]['bp' + str(bp + 1)] = np.array([0,0]) |
| 237 | + |
| 238 | + return skeletons |
| 239 | + animal_skeletons = extract_to_animal_skeleton(pose['coordinates']) |
| 240 | + # animal_skeletons = list(animal_skeletons.values()) |
| 241 | + |
| 242 | + return animal_skeletons |
| 243 | + |
| 244 | + |
| 245 | +"""DLC LIVE""" |
| 246 | + |
| 247 | + |
191 | 248 |
|
192 | 249 | def transform_2skeleton(pose): |
193 | 250 | from utils.configloader import ALL_BODYPARTS |
@@ -215,3 +272,28 @@ def calculate_skeletons_dlc_live(pose ,animals_number: int = 1) -> list: |
215 | 272 | skeletons = [transform_2skeleton(pose)] |
216 | 273 |
|
217 | 274 | return skeletons |
| 275 | + |
| 276 | + |
| 277 | +def calculate_skeletons(peaks: dict, animals_number: int) -> list: |
| 278 | + """ |
| 279 | + Creating skeletons from given peaks |
| 280 | + There could be no more skeletons than animals_number |
| 281 | + Only unique skeletons output |
| 282 | + adaptive to chosen model origin |
| 283 | + """ |
| 284 | + |
| 285 | + if MODEL_ORIGIN == 'DLC': |
| 286 | + animal_skeletons = calculate_dlstream_skeletons(peaks, animals_number) |
| 287 | + |
| 288 | + elif MODEL_ORIGIN == 'MADLC': |
| 289 | + animal_skeletons = calculate_ma_skeletons(peaks, animals_number) |
| 290 | + |
| 291 | + elif MODEL_ORIGIN == 'DLC-LIVE': |
| 292 | + animal_skeletons = calculate_skeletons_dlc_live(peaks, animals_number= 1) |
| 293 | + if animals_number != 1: |
| 294 | + raise ValueError('Multiple animals are currently not supported by DLC-LIVE.' |
| 295 | + ' If you are using differently colored animals, please refere to the bodyparts directly.') |
| 296 | + |
| 297 | + |
| 298 | + |
| 299 | + |
0 commit comments