|
17 | 17 | import pandas as pd |
18 | 18 | import click |
19 | 19 |
|
20 | | -from utils.configloader import RESOLUTION, FRAMERATE, OUT_DIR, MODEL, MULTI_CAM, STACK_FRAMES, \ |
| 20 | +from utils.configloader import RESOLUTION, FRAMERATE, OUT_DIR, MODEL_NAME, MULTI_CAM, STACK_FRAMES, \ |
21 | 21 | ANIMALS_NUMBER, STREAMS, STREAMING_SOURCE |
22 | | -from utils.poser import load_deeplabcut, get_pose, find_local_peaks_new, calculate_skeletons |
| 22 | +from utils.poser import load_deeplabcut, get_pose, find_local_peaks_new, calculate_skeletons,\ |
| 23 | + get_ma_pose, calculate_ma_skeletons, calculate_skeletons_dlc_live, transform_2skeleton |
23 | 24 | from utils.plotter import plot_bodyparts, plot_metadata_frame |
24 | 25 |
|
25 | 26 |
|
@@ -266,19 +267,61 @@ def write_video(self, frames: dict, index: int): |
266 | 267 | ###################### |
267 | 268 | @staticmethod |
268 | 269 | def get_pose_mp(input_q, output_q): |
| 270 | + from utils.configloader import MODEL_ORIGIN |
| 271 | + from utils.poser import get_ma_pose |
| 272 | + |
269 | 273 | """ |
270 | 274 | Process to be used for each camera/DLC stream of analysis |
271 | 275 | Designed to be run in an infinite loop |
272 | 276 | :param input_q: index and corresponding frame |
273 | 277 | :param output_q: index and corresponding analysis |
274 | 278 | """ |
275 | | - config, sess, inputs, outputs = load_deeplabcut() |
276 | | - while True: |
277 | | - if input_q.full(): |
278 | | - index, frame = input_q.get() |
279 | | - scmap, locref, pose = get_pose(frame, config, sess, inputs, outputs) |
280 | | - peaks = find_local_peaks_new(scmap, locref, ANIMALS_NUMBER, config) |
281 | | - output_q.put((index, peaks)) |
| 279 | + |
| 280 | + if MODEL_ORIGIN == 'DLC' or MODEL_ORIGIN == 'MADLC': |
| 281 | + config, sess, inputs, outputs = load_deeplabcut() |
| 282 | + while True: |
| 283 | + if input_q.full(): |
| 284 | + index, frame = input_q.get() |
| 285 | + if MODEL_ORIGIN == 'DLC': |
| 286 | + scmap, locref, pose = get_pose(frame, config, sess, inputs, outputs) |
| 287 | + # TODO: REmove alterations to original |
| 288 | + #peaks = find_local_peaks_new(scmap, locref, ANIMALS_NUMBER, config) |
| 289 | + peaks = pose |
| 290 | + if MODEL_ORIGIN == 'MADLC': |
| 291 | + peaks = get_ma_pose(frame, config, sess, inputs, outputs) |
| 292 | + |
| 293 | + output_q.put((index, peaks)) |
| 294 | + |
| 295 | + elif MODEL_ORIGIN == 'DLC-LIVE': |
| 296 | + from dlclive import DLCLive |
| 297 | + from utils.configloader import MODEL_PATH |
| 298 | + dlc_live = DLCLive(MODEL_PATH) |
| 299 | + while True: |
| 300 | + if input_q.full(): |
| 301 | + index, frame = input_q.get() |
| 302 | + if not dlc_live.is_initialized: |
| 303 | + peaks = dlc_live.init_inference(frame) |
| 304 | + else: |
| 305 | + peaks = dlc_live.get_pose(frame) |
| 306 | + |
| 307 | + output_q.put((index, peaks)) |
| 308 | + elif MODEL_ORIGIN == 'DEEPPOSEKIT': |
| 309 | + from deepposekit.models import load_model |
| 310 | + from utils.configloader import MODEL_PATH |
| 311 | + model = load_model(MODEL_PATH) |
| 312 | + predict_model = model.predict_model |
| 313 | + while True: |
| 314 | + if input_q.full(): |
| 315 | + index, frame = input_q.get() |
| 316 | + frame = frame[..., 1][..., None] |
| 317 | + st_frame = np.stack([frame]) |
| 318 | + prediction = predict_model.predict(st_frame, batch_size=1, verbose=True) |
| 319 | + peaks= prediction[0,:,:2] |
| 320 | + output_q.put((index, peaks)) |
| 321 | + |
| 322 | + |
| 323 | + else: |
| 324 | + raise ValueError(f'Model origin {MODEL_ORIGIN} not available.') |
282 | 325 |
|
283 | 326 | @staticmethod |
284 | 327 | def create_mp_tools(devices): |
@@ -366,7 +409,9 @@ def get_analysed_frames(self) -> tuple: |
366 | 409 |
|
367 | 410 | # Getting the analysed data |
368 | 411 | analysed_index, peaks = self._multiprocessing[camera]['output'].get() |
369 | | - skeletons = calculate_skeletons(peaks, ANIMALS_NUMBER) |
| 412 | + #TODO: REMOVE IF USELESS |
| 413 | + skeletons = [transform_2skeleton(peaks)] |
| 414 | + #skeletons = calculate_skeletons(peaks, ANIMALS_NUMBER) |
370 | 415 | print('', end='\r', flush=True) # this is the line you should not remove |
371 | 416 | analysed_frame, depth_map, input_time = self.get_stored_frames(camera) |
372 | 417 | analysis_time = time.time() - input_time |
@@ -722,7 +767,7 @@ def show_benchmark_statistics(): |
722 | 767 |
|
723 | 768 | if benchmark_enabled: |
724 | 769 | import re |
725 | | - short_model = re.split('[-_]', MODEL) |
| 770 | + short_model = re.split('[-_]', MODEL_NAME) |
726 | 771 | short_model = short_model[0] + '_' + short_model[2] |
727 | 772 | np.savetxt(f'{OUT_DIR}/{short_model}_framerate_{FRAMERATE}_resolution_{RESOLUTION[0]}_{RESOLUTION[1]}.txt', np.transpose([fps_data, whole_loop_time_data])) |
728 | 773 |
|
|
0 commit comments