Skip to content

Commit 112a6ab

Browse files
committed
Merge remote-tracking branch 'origin/dev' into dev
2 parents e6321e9 + 45246d0 commit 112a6ab

File tree

9 files changed

+209
-47
lines changed

9 files changed

+209
-47
lines changed

DeepLabStream.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,19 @@ def get_pose_mp(input_q, output_q):
302302

303303
output_q.put((index, peaks))
304304
elif MODEL_ORIGIN == 'DEEPPOSEKIT':
305-
print('Not here yet...')
305+
from deepposekit.models import load_model
306+
from utils.configloader import MODEL_PATH
307+
model = load_model(MODEL_PATH)
308+
predict_model = model.predict_model
309+
while True:
310+
if input_q.full():
311+
index, frame = input_q.get()
312+
frame = frame[..., 1][..., None]
313+
st_frame = np.stack([frame])
314+
prediction = predict_model.predict(st_frame, batch_size=1, verbose=True)
315+
peaks= prediction[0,:,:2]
316+
output_q.put((index, peaks))
317+
306318

307319
else:
308320
raise ValueError(f'Model origin {MODEL_ORIGIN} not available.')

experiments/base/stimulation.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,35 @@ def __init__(self):
1919
self._name = 'BaseStimulation'
2020
self._parameter_dict = dict(TYPE='str',
2121
PORT='str',
22+
IP = 'str',
2223
STIM_TIME='float')
2324
self._settings_dict = get_stimulation_settings(self._name, self._parameter_dict)
2425
self._running = False
25-
self._stim_device = self._setup_device(self._settings_dict['TYPE'], self._settings_dict['PORT'])
26+
self._stim_device = self._setup_device(self._settings_dict['TYPE'], self._settings_dict['PORT'], self._settings_dict['IP'])
2627

2728
@staticmethod
28-
def _setup_device(type, port):
29+
def _setup_device(type, port, ip):
2930
device = None
3031
if type == 'NI':
3132
from experiments.utils.DAQ_output import DigitalModDevice
3233
device = DigitalModDevice(port)
3334

35+
if type == 'RASPBERRY':
36+
from experiments.utils.gpio_control import DigitalPiDevice
37+
device = DigitalPiDevice(port)
38+
39+
if type == 'RASP_NETWORK':
40+
from experiments.utils.gpio_control import DigitalPiDevice
41+
if ip is not None:
42+
device = DigitalPiDevice(port, ip)
43+
else:
44+
raise ValueError('IP required for remote GPIO control.')
45+
46+
if type == 'ARDUINO':
47+
from experiments.utils.gpio_control import DigitalArduinoDevice
48+
device = DigitalArduinoDevice(port)
49+
50+
3451
return device
3552

3653
def stimulate(self):
@@ -72,23 +89,38 @@ class RewardDispenser(BaseStimulation):
7289
def __init__(self):
7390
self._name = 'RewardDispenser'
7491
self._parameter_dict = dict(TYPE = 'str',
92+
IP = 'str',
7593
STIM_PORT= 'str',
7694
REMOVAL_PORT = 'str',
7795
STIM_TIME = 'float',
7896
REMOVAL_TIME = 'float')
7997
self._settings_dict = get_stimulation_settings(self._name, self._parameter_dict)
8098
self._running = False
81-
self._stim_device = self._setup_device(self._settings_dict['TYPE'], self._settings_dict['STIM_PORT'])
82-
self._removal_device = self._setup_device(self._settings_dict['TYPE'], self._settings_dict['REMOVAL_PORT'])
99+
self._stim_device = self._setup_device(self._settings_dict['TYPE'], self._settings_dict['STIM_PORT'],
100+
self._settings_dict['IP'])
101+
self._removal_device = self._setup_device(self._settings_dict['TYPE'], self._settings_dict['REMOVAL_PORT'],
102+
self._settings_dict['IP'])
83103

84104

85105
@staticmethod
86-
def _setup_device(type, port):
106+
def _setup_device(type, port, ip):
87107
device = None
88108
if type == 'NI':
89109
from experiments.utils.DAQ_output import DigitalModDevice
90110
device = DigitalModDevice(port)
91111

112+
if type == 'RASPBERRY':
113+
from experiments.utils.gpio_control import DigitialPiBoardDevice
114+
device = DigitialPiBoardDevice(port)
115+
116+
if type == 'RASP_NETWORK':
117+
from experiments.utils.gpio_control import DigitialPiBoardDevice
118+
if ip is not None:
119+
device = DigitialPiBoardDevice(port, ip)
120+
else:
121+
raise ValueError('IP required for remote GPIO control.')
122+
123+
92124
return device
93125

94126
def stimulate(self):

experiments/configs/default_config.ini

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,20 @@ STIMULATION = BaseStimulation
9898
;[STIMULATION]
9999

100100
[BaseStimulation]
101+
; can be NI, RASPBERRY or RASP_NETWORK
101102
TYPE = NI
103+
;only used in RASP_NETWORK
104+
IP = None
105+
;PORT parameter is used for all (Port from DAQ, PIN from Raspberry, or serial port from Arduino)
102106
PORT = Dev1/PFI6
103107
STIM_TIME = 3.5
104108

105109
[RewardDispenser]
110+
; can be NI, RASPBERRY, RASP_NETWORK or ARDUINO
106111
TYPE = NI
112+
;only used in RASP_NETWORK
113+
IP = None
114+
;PORT parameter is used for all (Port from DAQ, PIN from Raspberry, or serial port from Arduino)
107115
STIM_PORT = Dev1/PFI6
108116
REMOVAL_PORT = Dev1/PFI5
109117
STIM_TIME = 3.5

experiments/utils/gpio_control.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
from gpiozero import DigitalOutputDevice
2+
from gpiozero.pins.pigpio import PiGPIOFactory
3+
4+
import serial
5+
6+
7+
8+
class DigitalPiDevice:
9+
"""
10+
Digital modulated devices in combination with Raspberry Pi GPIO
11+
Setup: https://gpiozero.readthedocs.io/en/stable/remote_gpio.html
12+
"""
13+
14+
def __init__(self, PIN, BOARD_IP: str = None):
15+
16+
"""
17+
:param BOARD_IP: IP adress of board connected to the Device
18+
"""
19+
if BOARD_IP is not None:
20+
self._factory = PiGPIOFactory(host = BOARD_IP)
21+
self._device = DigitalOutputDevice(PIN, pin_factory = self._factory)
22+
else:
23+
self._factory = None
24+
self._device = DigitalOutputDevice(PIN)
25+
self._running = False
26+
27+
def turn_on(self):
28+
self._device.on()
29+
self._running = True
30+
31+
def turn_off(self):
32+
self._device.off()
33+
self._running = False
34+
35+
def toggle(self):
36+
self._device.toggle()
37+
self._running = self._device.is_active
38+
39+
40+
class DigitalArduinoDevice:
41+
"""
42+
Digital modulated devices in combination with Arduino boards connected via USB
43+
setup: https://pythonforundergradengineers.com/python-arduino-LED.html
44+
45+
"""
46+
47+
def __init__(self, PORT):
48+
"""
49+
:param PORT: USB PORT of the arduino board
50+
"""
51+
self._device = serial.Serial(PORT, baudrate=9600)
52+
self._running = False
53+
54+
def turn_on(self):
55+
self._device.write(b'H')
56+
self._running = True
57+
58+
def turn_off(self):
59+
self._device.write(b'L')
60+
self._running = False
61+
62+
def toggle(self):
63+
if self._running:
64+
self.turn_off()
65+
else:
66+
self.turn_on()

experiments/utils/gpio_raspberry.py

Lines changed: 0 additions & 36 deletions
This file was deleted.

requirements.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1+
gpiozero
2+
pigpio
3+
pyserial
14
nidaqmx>=0.5.7
25
click>=7.0
36
opencv-python>=3.4.5.20
47
numpy>=1.14.5
58
pandas>=0.21.0
69
matplotlib>=3.0.3
710
scikit-image>=0.14.2
8-
scipy>=1.1.0
11+
scipy>=1.1.0

utils/VideoAnalyzer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
from utils.configloader import VIDEO_SOURCE, OUT_DIR, ANIMALS_NUMBER
1919
from experiments.custom.experiments import ExampleExperiment
2020

21+
from deepposekit.models import load_model
22+
2123

2224
def create_dataframes(data_output):
2325
"""
@@ -43,7 +45,7 @@ def create_dataframes(data_output):
4345

4446
def start_videoanalyser():
4547
print("Starting DeepLabCut")
46-
config, sess, inputs, outputs = load_deeplabcut()
48+
model = load_model(r"D:\DeepPoseKit-Data-master\datasets\fly\best_model_densenet.h5")
4749

4850
experiment_enabled = False
4951
video_output = True

utils/poser.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def extract_to_animal_skeleton(coords):
240240
return animal_skeletons
241241

242242

243-
"""DLC LIVE"""
243+
"""DLC LIVE & DeepPoseKit"""
244244

245245

246246

@@ -286,11 +286,11 @@ def calculate_skeletons(peaks: dict, animals_number: int) -> list:
286286
elif MODEL_ORIGIN == 'MADLC':
287287
animal_skeletons = calculate_ma_skeletons(peaks, animals_number)
288288

289-
elif MODEL_ORIGIN == 'DLC-LIVE':
289+
elif MODEL_ORIGIN == 'DLC-LIVE' or MODEL_ORIGIN == 'DEEPPOSEKIT':
290290
animal_skeletons = calculate_skeletons_dlc_live(peaks, animals_number= 1)
291291
if animals_number != 1:
292292
raise ValueError('Multiple animals are currently not supported by DLC-LIVE.'
293-
' If you are using differently colored animals, please refere to the bodyparts directly.')
293+
' If you are using differently colored animals, please refer to the bodyparts directly.')
294294

295295
return animal_skeletons
296296

utils/videoanalyzer_dpp.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
"""
2+
DeepLabStream
3+
© J.Schweihoff, M. Loshakov
4+
University Bonn Medical Faculty, Germany
5+
https://github.com/SchwarzNeuroconLab/DeepLabStream
6+
Licensed under GNU General Public License v3.0
7+
"""
8+
9+
import cv2
10+
from deepposekit.models import load_model
11+
from deepposekit.io import VideoReader
12+
import numpy as np
13+
import tensorflow as tf
14+
15+
16+
17+
def plot_dlc_bodyparts(image, bodyparts):
18+
"""
19+
Plots dlc bodyparts on given image
20+
adapted from plotter
21+
"""
22+
23+
for bp in bodyparts:
24+
center = tuple(bp.astype(int))
25+
cv2.circle(image, center=center, radius=3, color=(255, 0, 0), thickness=2)
26+
return image
27+
28+
def start_videoanalyser():
29+
print("Starting DeepPoseKit")
30+
video = cv2.VideoCapture(r"D:\DeepPoseKit-Data-master\datasets\fly\video.avi")
31+
resolution = int(video.get(cv2.CAP_PROP_FRAME_WIDTH)), int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
32+
33+
path_model = r"D:\DeepPoseKit-Data-master\datasets\fly\best_model_densenet.h5"
34+
model = load_model(path_model)
35+
36+
37+
predict_model = model.predict_model
38+
# predict_model.layers.pop(0) # remove current input layer
39+
#
40+
# inputs = tf.keras.layers.Input((resolution[0], resolution[1], 3))
41+
# outputs = predict_model(inputs)
42+
# predict_model = tf.keras.Model(inputs, outputs)
43+
44+
45+
experiment_enabled = False
46+
47+
index = 0
48+
while video.isOpened():
49+
ret, frame = video.read()
50+
if ret is not None:
51+
org_frame = frame
52+
frame = frame[..., 1][..., None]
53+
st_frame = np.stack([frame])
54+
prediction = predict_model.predict(st_frame, batch_size= 1, verbose=True)
55+
x, y, confidence = np.split(prediction, 3, -1)
56+
57+
print(prediction.shape)
58+
predi = prediction[0,:,:2]
59+
pre = predi[:, :2]
60+
print(pre)
61+
out_frame = plot_dlc_bodyparts(org_frame, predi)
62+
# out_frame = org_frame
63+
cv2.imshow('stream', out_frame)
64+
index += 1
65+
else:
66+
break
67+
68+
if cv2.waitKey(1) & 0xFF == ord('q'):
69+
break
70+
71+
video.release()
72+
73+
74+
if __name__ == "__main__":
75+
start_videoanalyser()

0 commit comments

Comments
 (0)