Skip to content

Commit 5d1f850

Browse files
authored
Merge pull request #3 from SchwarzNeuroconLab/dev_performanceimprov
several performance improvements
2 parents 98c8da0 + e2934da commit 5d1f850

File tree

8 files changed

+185
-49
lines changed

8 files changed

+185
-49
lines changed

DeepLabStream.py

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from utils.generic import VideoManager, WebCamManager, GenericManager
2121
from utils.configloader import RESOLUTION, FRAMERATE, OUT_DIR, MODEL_NAME, MULTI_CAM, STACK_FRAMES, \
22-
ANIMALS_NUMBER, STREAMS, STREAMING_SOURCE, MODEL_ORIGIN
22+
ANIMALS_NUMBER, STREAMS, STREAMING_SOURCE, MODEL_ORIGIN, CROP, CROP_X, CROP_Y
2323
from utils.plotter import plot_bodyparts, plot_metadata_frame
2424
from utils.poser import load_deeplabcut, load_dpk, load_dlc_live, get_pose, calculate_skeletons,\
2525
find_local_peaks_new, get_ma_pose
@@ -281,38 +281,42 @@ def get_pose_mp(input_q, output_q):
281281
while True:
282282
if input_q.full():
283283
index, frame = input_q.get()
284+
start_time = time.time()
284285
if MODEL_ORIGIN == 'DLC':
285286
scmap, locref, pose = get_pose(frame, config, sess, inputs, outputs)
286287
# TODO: Remove alterations to original
287288
peaks = find_local_peaks_new(scmap, locref, ANIMALS_NUMBER, config)
288289
# peaks = pose
289290
if MODEL_ORIGIN == 'MADLC':
290291
peaks = get_ma_pose(frame, config, sess, inputs, outputs)
291-
292-
output_q.put((index, peaks))
292+
analysis_time = time.time() - start_time
293+
output_q.put((index, peaks, analysis_time))
293294

294295
elif MODEL_ORIGIN == 'DLC-LIVE':
295296
dlc_live = load_dlc_live()
296297
while True:
297298
if input_q.full():
298299
index, frame = input_q.get()
300+
start_time = time.time()
299301
if not dlc_live.is_initialized:
300302
peaks = dlc_live.init_inference(frame)
301303
else:
302304
peaks = dlc_live.get_pose(frame)
303-
304-
output_q.put((index, peaks))
305+
analysis_time = time.time() - start_time
306+
output_q.put((index, peaks, analysis_time))
305307

306308
elif MODEL_ORIGIN == 'DEEPPOSEKIT':
307309
predict_model = load_dpk()
308310
while True:
309311
if input_q.full():
310312
index, frame = input_q.get()
313+
start_time = time.time()
311314
frame = frame[..., 1][..., None]
312315
st_frame = np.stack([frame])
313316
prediction = predict_model.predict(st_frame, batch_size=1, verbose=True)
314317
peaks = prediction[0, :, :2]
315-
output_q.put((index, peaks))
318+
analysis_time = time.time() - start_time
319+
output_q.put((index,peaks,analysis_time))
316320
else:
317321
raise ValueError(f'Model origin {MODEL_ORIGIN} not available.')
318322

@@ -360,8 +364,12 @@ def get_frames(self) -> tuple:
360364
c_frames, d_maps, i_frames = self._camera_manager.get_frames()
361365
for camera in c_frames:
362366
c_frames[camera] = np.asanyarray(c_frames[camera])
367+
if CROP:
368+
c_frames[camera] = c_frames[camera][CROP_Y[0]:CROP_Y[1],CROP_X[0]:CROP_X[1]].copy()
369+
363370
for camera in i_frames:
364371
i_frames[camera] = np.asanyarray(i_frames[camera])
372+
365373
return c_frames, d_maps, i_frames
366374

367375
def input_frames_for_analysis(self, frames: tuple, index: int):
@@ -379,9 +387,9 @@ def input_frames_for_analysis(self, frames: tuple, index: int):
379387
frame_time = time.time()
380388
self._multiprocessing[camera]['input'].put((index, frame))
381389
if d_maps:
382-
self.store_frames(camera, frame, d_maps[camera], frame_time)
390+
self.store_frames(camera, frame, d_maps[camera], frame_time, index)
383391
else:
384-
self.store_frames(camera, frame, None, frame_time)
392+
self.store_frames(camera, frame, None, frame_time, index)
385393

386394
def get_analysed_frames(self) -> tuple:
387395
"""
@@ -401,12 +409,11 @@ def get_analysed_frames(self) -> tuple:
401409
self._start_time = time.time() # getting the first frame here
402410

403411
# Getting the analysed data
404-
analysed_index, peaks = self._multiprocessing[camera]['output'].get()
412+
analysed_index, peaks, analysis_time = self._multiprocessing[camera]['output'].get()
405413
skeletons = calculate_skeletons(peaks, ANIMALS_NUMBER)
406414
print('', end='\r', flush=True) # this is the line you should not remove
407-
analysed_frame, depth_map, input_time = self.get_stored_frames(camera)
408-
analysis_time = time.time() - input_time
409-
415+
analysed_frame , depth_map, input_time = self.get_stored_frames(camera, analysed_index)
416+
delay_time = time.time() - input_time
410417
# Calculating FPS and plotting the data on frame
411418
self.calculate_fps(analysis_time if analysis_time != 0 else 0.01)
412419
frame_time = time.time() - self._start_time
@@ -430,23 +437,30 @@ def get_analysed_frames(self) -> tuple:
430437
analysed_frames[camera] = analysed_image
431438
return analysed_frames, analysis_time
432439

433-
def store_frames(self, camera: str, c_frame, d_map, frame_time: float):
440+
def store_frames(self, camera: str, c_frame, d_map, frame_time: float, index: int):
434441
"""
435-
Store frames currently sent for analysis
442+
Store frames currently sent for analysis in an index-based dictionary
436443
:param camera: camera name
437444
:param c_frame: color frame
438445
:param d_map: depth map
439446
:param frame_time: inputting time of frameset
447+
:param index: index of frame that is currently analysed
440448
"""
441-
self._stored_frames[camera] = c_frame, d_map, frame_time
449+
if camera in self._stored_frames.keys():
450+
self._stored_frames[camera][index] = c_frame, d_map, frame_time
451+
452+
else:
453+
self._stored_frames[camera] = {}
454+
self._stored_frames[camera][index] = c_frame, d_map, frame_time
442455

443-
def get_stored_frames(self, camera: str):
456+
def get_stored_frames(self, camera: str, index: int):
444457
"""
445-
Retrieve frames currently sent for analysis
458+
Retrieve frames currently sent for analysis; retrieved frames will be removed (popped) from the dictionary
446459
:param camera: camera name
460+
:param index: index of analysed frame
447461
:return:
448462
"""
449-
c_frame, d_map, frame_time = self._stored_frames.get(camera)
463+
c_frame, d_map, frame_time = self._stored_frames[camera].pop(index, None)
450464
return c_frame, d_map, frame_time
451465

452466
def convert_depth_map_to_image(self, d_map):

app.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@
1111
import cv2
1212

1313
from DeepLabStream import DeepLabStream, show_stream
14+
from utils.generic import MissingFrameError
1415
from utils.configloader import MULTI_CAM, STREAMS, RECORD_EXP
1516
from utils.gui_image import QFrame, ImageWindow, emit_qframes
1617

17-
from PyQt5.QtCore import QThread
18-
from PyQt5.QtWidgets import QPushButton, QApplication, QWidget, QGridLayout
19-
from PyQt5.QtGui import QIcon
20-
18+
from PySide2.QtCore import QThread
19+
from PySide2.QtWidgets import QPushButton, QApplication, QWidget, QGridLayout
20+
from PySide2.QtGui import QIcon
2121

2222
# creating a complete thread process to work in the background
2323
class AThread(QThread):
@@ -42,7 +42,16 @@ def run(self):
4242
Infinite loop with all the streaming, analysis and recording logic
4343
"""
4444
while self.threadactive:
45-
all_frames = stream_manager.get_frames()
45+
try:
46+
all_frames = stream_manager.get_frames()
47+
except MissingFrameError as e:
48+
"""catch missing frame, stop Thread and save what can be saved"""
49+
print(*e.args, '\nShutting down DLStream and saving data...')
50+
stream_manager.finish_streaming()
51+
stream_manager.stop_cameras()
52+
self.stop()
53+
break
54+
4655
color_frames, depth_maps, infra_frames = all_frames
4756

4857
# writing the video

settings.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ OUTPUT_DIRECTORY = /Output
55
#if you have connected multiple cameras (USB), you will need to select the number OpenCV has given them.
66
#Default is "0", which takes the first available camera.
77
CAMERA_SOURCE = 0
8-
#you can use "camera", "ipwebcam" or "video" to select your input source
8+
#you can use camera, ipwebcam or video to select your input source
99
STREAMING_SOURCE = video
1010

1111
[Pose Estimation]

utils/advanced_settings.ini

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,15 @@
44
STREAMS = color, depth, infrared
55
MULTIPLE_DEVICES = False
66
STACK_FRAMES = False
7-
ANIMALS_NUMBER = 1
7+
ANIMALS_NUMBER = 1
8+
9+
CROP = False
10+
CROP_X = 0, 50
11+
CROP_Y = 0, 50
12+
13+
[Pose Estimation]
14+
FLATTEN_MA = True
15+
HANDLE_MISSING = skip
16+
17+
[Video]
18+
REPEAT_VIDEO = True

utils/configloader.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,14 +27,18 @@ def get_script_path():
2727
adv_cfg_path = os.path.join(os.path.dirname(__file__), 'advanced_settings.ini')
2828
with open(adv_cfg_path) as adv_cfg_file:
2929
adv_dsc_config.read_file(adv_cfg_file)
30+
# DeepLabCut
31+
#deeplabcut_config = dict(dsc_config.items('DeepLabCut'))
3032

3133
#poseestimation
3234
MODEL_ORIGIN = dsc_config['Pose Estimation'].get('MODEL_ORIGIN')
33-
MODEL_PATH = dsc_config['Pose Estimation'].get('MODEL_PATH')
35+
model_path_string = [str(part).strip() for part in dsc_config['Pose Estimation'].get('MODEL_PATH').split(',')]
36+
MODEL_PATH = model_path_string[0] if len(model_path_string) <= 1 else model_path_string
3437
MODEL_NAME = dsc_config['Pose Estimation'].get('MODEL_NAME')
3538
ALL_BODYPARTS = tuple(part for part in dsc_config['Pose Estimation'].get('ALL_BODYPARTS').split(','))
3639

3740
# Streaming items
41+
3842
try:
3943
RESOLUTION = tuple(int(part) for part in dsc_config['Streaming'].get('RESOLUTION').split(','))
4044
except ValueError:
@@ -67,3 +71,13 @@ def get_script_path():
6771
'STACK_FRAMES') is not None else False
6872
ANIMALS_NUMBER = adv_dsc_config['Streaming'].getint('ANIMALS_NUMBER') if adv_dsc_config['Streaming'].getint(
6973
'ANIMALS_NUMBER') is not None else 1
74+
75+
REPEAT_VIDEO = adv_dsc_config['Video'].getboolean('REPEAT_VIDEO')
76+
CROP = adv_dsc_config['Streaming'].getboolean('CROP')
77+
CROP_X = [int(str(part).strip()) for part in adv_dsc_config['Streaming'].get('CROP_X').split(',')]
78+
CROP_Y = [int(str(part).strip()) for part in adv_dsc_config['Streaming'].get('CROP_Y').split(',')]
79+
80+
FLATTEN_MA = adv_dsc_config['Pose Estimation'].getboolean('FLATTEN_MA')
81+
HANDLE_MISSING = adv_dsc_config['Pose Estimation'].get('HANDLE_MISSING')
82+
83+

utils/generic.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,11 @@
1212
import numpy as np
1313
import zmq
1414

15-
from utils.configloader import CAMERA_SOURCE, VIDEO_SOURCE, RESOLUTION, FRAMERATE, PORT
15+
from utils.configloader import CAMERA_SOURCE, VIDEO_SOURCE, RESOLUTION, FRAMERATE, PORT, REPEAT_VIDEO
1616

17+
class MissingFrameError(Exception):
18+
"""Custom expection to be raised when frame is not received. Should be caught in app.py and deeplabstream.py
19+
to stop dlstream gracefully"""
1720

1821
class GenericManager:
1922
"""
@@ -24,11 +27,13 @@ def __init__(self):
2427
Generic camera manager from video source
2528
Uses pure opencv
2629
"""
27-
source = CAMERA_SOURCE if CAMERA_SOURCE is not None else 0
30+
self._source = CAMERA_SOURCE if CAMERA_SOURCE is not None else 0
2831
self._manager_name = "generic"
2932
self._enabled_devices = {}
30-
self._camera = cv2.VideoCapture(int(source))
31-
self._camera_name = "Camera {}".format(source)
33+
self._camera = None
34+
#Will be called when enabling stream! Important for restart of stream
35+
#self._camera = cv2.VideoCapture(int(self._source))
36+
self._camera_name = "Camera {}".format(self._source)
3237

3338
def get_connected_devices(self) -> list:
3439
"""
@@ -48,6 +53,7 @@ def enable_stream(self, resolution, framerate, *args):
4853
(hopefully)
4954
"""
5055
width, height = resolution
56+
self._camera = cv2.VideoCapture(int(self._source))
5157
self._camera.set(cv2.CAP_PROP_FRAME_WIDTH, width)
5258
self._camera.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
5359
self._camera.set(cv2.CAP_PROP_FPS, framerate)
@@ -76,6 +82,9 @@ def get_frames(self) -> tuple:
7682
ret, image = self._camera.read()
7783
if ret:
7884
color_frames[self._camera_name] = image
85+
else:
86+
raise MissingFrameError('No frame was received from the camera. Make sure that the camera is connected '
87+
'and that the camera source is set correctly.')
7988

8089
return color_frames, depth_maps, infra_frames
8190

@@ -101,11 +110,22 @@ def __init__(self):
101110
Uses pure opencv
102111
"""
103112
super().__init__()
104-
self._camera = cv2.VideoCapture(VIDEO_SOURCE)
113+
#will be defined in enable_stream
114+
self._camera = None
105115
self._camera_name = "Video"
106116
self.initial_wait = False
107117
self.last_frame_time = time.time()
108118

119+
def enable_stream(self, resolution, framerate, *args):
120+
"""
121+
Enable one stream with given parameters
122+
(hopefully)
123+
"""
124+
# set video to first frame
125+
print('Thinking of beginning things...')
126+
self._camera = cv2.VideoCapture(VIDEO_SOURCE)
127+
self._camera.set(cv2.CAP_PROP_POS_FRAMES,0)
128+
109129
def get_frames(self) -> tuple:
110130
"""
111131
Collect frames for camera and outputs it in 'color' dictionary
@@ -128,10 +148,12 @@ def get_frames(self) -> tuple:
128148
if running_time <= 1 / FRAMERATE:
129149
sleepy_time = int(np.ceil(1000/FRAMERATE - running_time / 1000))
130150
cv2.waitKey(sleepy_time)
131-
else:
151+
elif REPEAT_VIDEO:
132152
# cycle the video for testing purposes
133153
self._camera.set(cv2.CAP_PROP_POS_FRAMES, 0)
134154
return self.get_frames()
155+
else:
156+
raise MissingFrameError('The video reached the end or is damaged. Use REPEAT_VIDEO in the advanced_settings to repeat videos.')
135157

136158
return color_frames, depth_maps, infra_frames
137159

@@ -194,6 +216,10 @@ def get_frames(self) -> tuple:
194216
if running_time <= 1 / FRAMERATE:
195217
sleepy_time = int(np.ceil(1000/FRAMERATE - running_time / 1000))
196218
cv2.waitKey(sleepy_time)
219+
220+
else:
221+
raise MissingFrameError('No frame was received from the webcam stream. Make sure that you started streaming on the host machine.')
222+
197223
return color_frames, depth_maps, infra_frames
198224

199225
def enable_stream(self, resolution, framerate, *args):

utils/plotter.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
"""
88

99
import cv2
10+
import numpy as np
1011

1112

1213
def plot_dots(image, coordinates, color, cond=False):
@@ -29,15 +30,19 @@ def plot_bodyparts(image, skeletons):
2930
# predefined colors list
3031
colors_list = [(0, 0, 255), (0, 255, 0), (0, 255, 255), (255, 0, 0), (255, 0, 255), (255, 255, 0), (255, 255, 128),
3132
(0, 0, 128), (0, 128, 0), (0, 128, 128), (0, 128, 255), (0, 255, 128), (128, 0, 0), (128, 0, 128),
32-
(128, 0, 255), (128, 128, 0), (128, 128, 128), (128, 128, 255), (128, 255, 0), (128, 255, 128),
33-
(128, 255, 255), (255, 0, 128), (255, 128, 0), (255, 128, 128), (255, 128, 255)]
33+
(128, 0, 255), (128, 128, 0), (128, 128, 128), (128, 128, 255), (128, 255, 0), (128, 255, 128),
34+
(128, 255, 255), (255, 0, 128), (255, 128, 0), (255, 128, 128), (255, 128, 255)]
35+
#color = (255, 0, 0)
3436

35-
for animal in skeletons:
37+
for num, animal in enumerate(skeletons):
3638
bodyparts = animal.keys()
3739
bp_count = len(bodyparts)
38-
colors = dict(zip(bodyparts, colors_list[:bp_count]))
40+
#colors = dict(zip(bodyparts, colors_list[:bp_count]))
3941
for part in animal:
40-
plot_dots(res_image, tuple(animal[part]), colors[part])
42+
#check for NaNs and skip
43+
if not any(np.isnan(animal[part])):
44+
plot_dots(res_image, tuple(map(int, animal[part])), colors_list[num])
45+
#plot_dots(res_image, tuple(animal[part]), colors[part])
4146
return res_image
4247

4348

0 commit comments

Comments
 (0)