Skip to content

Commit ccb38d8

Browse files
Merge pull request #2 from SchwarzNeuroconLab/fix/various-clean-up-fixes
Cleaned up the imports
2 parents 1900d7a + f94cafc commit ccb38d8

File tree

4 files changed

+118
-87
lines changed

4 files changed

+118
-87
lines changed

DeepLabStream.py

Lines changed: 38 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@
1717
import numpy as np
1818
import pandas as pd
1919

20-
from utils.configloader import RESOLUTION,FRAMERATE,OUT_DIR,MODEL_NAME,MULTI_CAM,STACK_FRAMES, \
21-
ANIMALS_NUMBER,STREAMS,STREAMING_SOURCE
22-
from utils.plotter import plot_bodyparts,plot_metadata_frame
23-
from utils.poser import load_deeplabcut,get_pose,calculate_skeletons
20+
from utils.generic import VideoManager, WebCamManager, GenericManager
21+
from utils.configloader import RESOLUTION, FRAMERATE, OUT_DIR, MODEL_NAME, MULTI_CAM, STACK_FRAMES, \
22+
ANIMALS_NUMBER, STREAMS, STREAMING_SOURCE, MODEL_ORIGIN
23+
from utils.plotter import plot_bodyparts, plot_metadata_frame
24+
from utils.poser import load_deeplabcut, load_dpk, load_dlc_live, get_pose, calculate_skeletons,\
25+
find_local_peaks_new, get_ma_pose
2426

2527

2628
def create_video_files(directory, devices, resolution, framerate, codec):
@@ -127,28 +129,19 @@ def set_camera_manager():
127129
:return: the chosen camera manager
128130
"""
129131

130-
if STREAMING_SOURCE.lower() == 'video':
131-
from utils.generic import VideoManager
132-
manager = VideoManager()
133-
return manager
134-
135-
elif STREAMING_SOURCE.lower() == 'ipwebcam':
136-
from utils.generic import WebCamManager
137-
manager = WebCamManager()
138-
return manager
139-
140-
elif STREAMING_SOURCE.lower() == 'camera':
132+
def select_camera_manager():
133+
"""
134+
Function to select from all available camera managers
135+
"""
141136
manager_list = []
142137
# loading realsense manager, if installed
143-
realsense = find_spec("pyrealsense2") is not None
144-
if realsense:
138+
if find_spec("pyrealsense2") is not None:
145139
from utils.realsense import RealSenseManager
146140
realsense_manager = RealSenseManager()
147141
manager_list.append(realsense_manager)
148142

149143
# loading basler manager, if installed
150-
pylon = find_spec("pypylon") is not None
151-
if pylon:
144+
if find_spec("pypylon") is not None:
152145
from utils.pylon import PylonManager
153146
pylon_manager = PylonManager()
154147
manager_list.append(pylon_manager)
@@ -170,9 +163,19 @@ def check_for_cameras(camera_manager):
170163
return manager
171164
else:
172165
# if no camera is found, try generic openCV manager
173-
from utils.generic import GenericManager
174166
generic_manager = GenericManager()
175167
return generic_manager
168+
169+
MANAGER_SOURCE = {
170+
'video': VideoManager,
171+
'ipwebcam': WebCamManager,
172+
'camera': select_camera_manager
173+
}
174+
175+
# initialize selected manager
176+
camera_manager = MANAGER_SOURCE.get(STREAMING_SOURCE)()
177+
if camera_manager is not None:
178+
return camera_manager
176179
else:
177180
raise ValueError(f'Streaming source {STREAMING_SOURCE} is not a valid option. \n'
178181
f'Please choose from "video", "camera" or "ipwebcam".')
@@ -266,35 +269,30 @@ def write_video(self, frames: dict, index: int):
266269
######################
267270
@staticmethod
268271
def get_pose_mp(input_q, output_q):
269-
from utils.configloader import MODEL_ORIGIN
270-
from utils.poser import get_ma_pose
271-
272272
"""
273273
Process to be used for each camera/DLC stream of analysis
274274
Designed to be run in an infinite loop
275275
:param input_q: index and corresponding frame
276276
:param output_q: index and corresponding analysis
277277
"""
278278

279-
if MODEL_ORIGIN == 'DLC' or MODEL_ORIGIN == 'MADLC':
280-
from utils.poser import find_local_peaks_new
279+
if MODEL_ORIGIN in ('DLC', 'MADLC'):
281280
config, sess, inputs, outputs = load_deeplabcut()
282281
while True:
283282
if input_q.full():
284283
index, frame = input_q.get()
285284
if MODEL_ORIGIN == 'DLC':
286285
scmap, locref, pose = get_pose(frame, config, sess, inputs, outputs)
286+
# TODO: Remove alterations to original
287287
peaks = find_local_peaks_new(scmap, locref, ANIMALS_NUMBER, config)
288-
#peaks = pose
288+
# peaks = pose
289289
if MODEL_ORIGIN == 'MADLC':
290290
peaks = get_ma_pose(frame, config, sess, inputs, outputs)
291291

292292
output_q.put((index, peaks))
293293

294294
elif MODEL_ORIGIN == 'DLC-LIVE':
295-
from dlclive import DLCLive
296-
from utils.configloader import MODEL_PATH
297-
dlc_live = DLCLive(MODEL_PATH)
295+
dlc_live = load_dlc_live()
298296
while True:
299297
if input_q.full():
300298
index, frame = input_q.get()
@@ -304,21 +302,17 @@ def get_pose_mp(input_q, output_q):
304302
peaks = dlc_live.get_pose(frame)
305303

306304
output_q.put((index, peaks))
305+
307306
elif MODEL_ORIGIN == 'DEEPPOSEKIT':
308-
from deepposekit.models import load_model
309-
from utils.configloader import MODEL_PATH
310-
model = load_model(MODEL_PATH)
311-
predict_model = model.predict_model
307+
predict_model = load_dpk()
312308
while True:
313309
if input_q.full():
314310
index, frame = input_q.get()
315311
frame = frame[..., 1][..., None]
316312
st_frame = np.stack([frame])
317313
prediction = predict_model.predict(st_frame, batch_size=1, verbose=True)
318-
peaks= prediction[0,:,:2]
314+
peaks = prediction[0, :, :2]
319315
output_q.put((index, peaks))
320-
321-
322316
else:
323317
raise ValueError(f'Model origin {MODEL_ORIGIN} not available.')
324318

@@ -763,10 +757,14 @@ def show_benchmark_statistics():
763757
print("[{0}/3000] Benchmarking in progress".format(len(analysis_time_data)))
764758

765759
if benchmark_enabled:
766-
import re
767-
short_model = re.split('[-_]', MODEL_NAME)
768-
short_model = short_model[0] + '_' + short_model[2]
769-
np.savetxt(f'{OUT_DIR}/{short_model}_framerate_{FRAMERATE}_resolution_{RESOLUTION[0]}_{RESOLUTION[1]}.txt', np.transpose([fps_data, whole_loop_time_data]))
760+
model_parts = MODEL_NAME.split('_')
761+
if len(model_parts) == 3:
762+
short_model = model_parts[0] + '_' + model_parts[2]
763+
else:
764+
short_model = MODEL_NAME
765+
# the best way to save files
766+
np.savetxt(f'{OUT_DIR}/{short_model}_framerate_{FRAMERATE}_resolution_{RESOLUTION[0]}_{RESOLUTION[1]}.txt',
767+
np.transpose([fps_data, whole_loop_time_data]))
770768

771769

772770
if __name__ == '__main__':

requirements.txt

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1-
gpiozero
2-
pigpio
3-
pyserial
4-
nidaqmx>=0.5.7
5-
click>=7.0
6-
opencv-python>=3.4.5.20
1+
gpiozero==1.5.1
2+
pigpio==1.78
3+
pyserial==3.5
4+
nidaqmx==0.5.7
5+
click==7.1.2
6+
opencv-python==3.4.5.20
7+
opencv-contrib-python==4.4.0.46
78
numpy>=1.14.5
8-
pandas>=0.21.0
9-
matplotlib>=3.0.3
10-
scikit-image>=0.14.2
11-
scipy>=1.1.0
9+
pandas==1.1.4
10+
matplotlib==3.0.3
11+
scikit-image==0.17.2
12+
scipy==1.5.4
13+
pyzmq==20.0.0

utils/generic.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,15 @@
55
https://github.com/SchwarzNeuroconLab/DeepLabStream
66
Licensed under GNU General Public License v3.0
77
"""
8+
import time
9+
import base64
810

911
import cv2
10-
from utils.configloader import CAMERA_SOURCE, VIDEO_SOURCE, RESOLUTION, FRAMERATE, PORT
11-
import time
1212
import numpy as np
13+
import zmq
14+
15+
from utils.configloader import CAMERA_SOURCE, VIDEO_SOURCE, RESOLUTION, FRAMERATE, PORT
16+
1317

1418
class GenericManager:
1519
"""
@@ -86,7 +90,6 @@ def get_name(self) -> str:
8690
return self._manager_name
8791

8892

89-
9093
class VideoManager(GenericManager):
9194

9295
"""
@@ -97,13 +100,12 @@ def __init__(self):
97100
Generic video manager from video files
98101
Uses pure opencv
99102
"""
100-
self._manager_name = "generic"
103+
super().__init__()
101104
self._camera = cv2.VideoCapture(VIDEO_SOURCE)
102105
self._camera_name = "Video"
103106
self.initial_wait = False
104107
self.last_frame_time = time.time()
105108

106-
107109
def get_frames(self) -> tuple:
108110
"""
109111
Collect frames for camera and outputs it in 'color' dictionary
@@ -116,7 +118,6 @@ def get_frames(self) -> tuple:
116118
infra_frames = {}
117119
ret, image = self._camera.read()
118120
self.last_frame_time = time.time()
119-
#print(ret)
120121
if ret:
121122
if not self.initial_wait:
122123
cv2.waitKey(1000)
@@ -127,6 +128,10 @@ def get_frames(self) -> tuple:
127128
if running_time <= 1 / FRAMERATE:
128129
sleepy_time = int(np.ceil(1000/FRAMERATE - running_time / 1000))
129130
cv2.waitKey(sleepy_time)
131+
else:
132+
# cycle the video for testing purposes
133+
self._camera.set(cv2.CAP_PROP_POS_FRAMES, 0)
134+
return self.get_frames()
130135

131136
return color_frames, depth_maps, infra_frames
132137

@@ -138,24 +143,23 @@ def __init__(self):
138143
Binds the computer to a ip address and starts listening for incoming streams.
139144
Adapted from StreamViewer.py https://github.com/CT83/SmoothStream
140145
"""
141-
import zmq
146+
super().__init__()
142147
self._context = zmq.Context()
143148
self._footage_socket = self._context.socket(zmq.SUB)
144149
self._footage_socket.bind('tcp://*:' + PORT)
145150
self._footage_socket.setsockopt_string(zmq.SUBSCRIBE, np.unicode(''))
146151

147-
self._manager_name = "generic"
148152
self._camera = None
149153
self._camera_name = "webcam"
150154
self.initial_wait = False
151155
self.last_frame_time = time.time()
152156

153-
def string_to_image(self, string):
154-
""" Taken from https://github.com/CT83/SmoothStream"""
157+
@ staticmethod
158+
def string_to_image(string):
159+
"""
160+
Taken from https://github.com/CT83/SmoothStream
161+
"""
155162

156-
import numpy as np
157-
import cv2
158-
import base64
159163
img = base64.b64decode(string)
160164
npimg = np.fromstring(img, dtype=np.uint8)
161165
return cv2.imdecode(npimg, 1)

0 commit comments

Comments
 (0)