Skip to content

Commit c1fb0a8

Browse files
committed
updating model options
updating stimulation options
1 parent 7bb7ace commit c1fb0a8

File tree

6 files changed

+167
-10
lines changed

6 files changed

+167
-10
lines changed

DeepLabStream.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,19 @@ def get_pose_mp(input_q, output_q):
302302

303303
output_q.put((index, peaks))
304304
elif MODEL_ORIGIN == 'DEEPPOSEKIT':
305-
print('Not here yet...')
305+
from deepposekit.models import load_model
306+
from utils.configloader import MODEL_PATH
307+
model = load_model(MODEL_PATH)
308+
predict_model = model.predict_model
309+
while True:
310+
if input_q.full():
311+
index, frame = input_q.get()
312+
frame = frame[..., 1][..., None]
313+
st_frame = np.stack([frame])
314+
prediction = predict_model.predict(st_frame, batch_size=1, verbose=True)
315+
peaks= prediction[0,:,:2]
316+
output_q.put((index, peaks))
317+
306318

307319
else:
308320
raise ValueError(f'Model origin {MODEL_ORIGIN} not available.')

experiments/base/stimulation.py

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,35 @@ def __init__(self):
1919
self._name = 'BaseStimulation'
2020
self._parameter_dict = dict(TYPE='str',
2121
PORT='str',
22+
IP = 'str',
2223
STIM_TIME='float')
2324
self._settings_dict = get_stimulation_settings(self._name, self._parameter_dict)
2425
self._running = False
25-
self._stim_device = self._setup_device(self._settings_dict['TYPE'], self._settings_dict['PORT'])
26+
self._stim_device = self._setup_device(self._settings_dict['TYPE'], self._settings_dict['PORT'], self._settings_dict['IP'])
2627

2728
@staticmethod
28-
def _setup_device(type, port):
29+
def _setup_device(type, port, ip):
2930
device = None
3031
if type == 'NI':
3132
from experiments.utils.DAQ_output import DigitalModDevice
3233
device = DigitalModDevice(port)
3334

35+
if type == 'RASPBERRY':
36+
from experiments.utils.generic_output import DigitalPiBoardDevice
37+
device = DigitalPiBoardDevice(port)
38+
39+
if type == 'RASP_NETWORK':
40+
from experiments.utils.generic_output import DigitalPiBoardDevice
41+
if ip is not None:
42+
device = DigitalPiBoardDevice(port, ip)
43+
else:
44+
raise ValueError('IP required for remote GPIO control.')
45+
46+
if type == 'ARDUINO':
47+
from experiments.utils.generic_output import DigitalArduinoDevice
48+
device = DigitalArduinoDevice(port)
49+
50+
3451
return device
3552

3653
def stimulate(self):
@@ -72,23 +89,38 @@ class RewardDispenser(BaseStimulation):
7289
def __init__(self):
7390
self._name = 'RewardDispenser'
7491
self._parameter_dict = dict(TYPE = 'str',
92+
IP = 'str',
7593
STIM_PORT= 'str',
7694
REMOVAL_PORT = 'str',
7795
STIM_TIME = 'float',
7896
REMOVAL_TIME = 'float')
7997
self._settings_dict = get_stimulation_settings(self._name, self._parameter_dict)
8098
self._running = False
81-
self._stim_device = self._setup_device(self._settings_dict['TYPE'], self._settings_dict['STIM_PORT'])
82-
self._removal_device = self._setup_device(self._settings_dict['TYPE'], self._settings_dict['REMOVAL_PORT'])
99+
self._stim_device = self._setup_device(self._settings_dict['TYPE'], self._settings_dict['STIM_PORT'],
100+
self._settings_dict['IP'])
101+
self._removal_device = self._setup_device(self._settings_dict['TYPE'], self._settings_dict['REMOVAL_PORT'],
102+
self._settings_dict['IP'])
83103

84104

85105
@staticmethod
86-
def _setup_device(type, port):
106+
def _setup_device(type, port, ip):
87107
device = None
88108
if type == 'NI':
89109
from experiments.utils.DAQ_output import DigitalModDevice
90110
device = DigitalModDevice(port)
91111

112+
if type == 'RASPBERRY':
113+
from experiments.utils.generic_output import DigitialPiBoardDevice
114+
device = DigitialPiBoardDevice(port)
115+
116+
if type == 'RASP_NETWORK':
117+
from experiments.utils.generic_output import DigitialPiBoardDevice
118+
if ip is not None:
119+
device = DigitialPiBoardDevice(port, ip)
120+
else:
121+
raise ValueError('IP required for remote GPIO control.')
122+
123+
92124
return device
93125

94126
def stimulate(self):

experiments/configs/default_config.ini

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,12 +98,20 @@ STIMULATION = BaseStimulation
9898
;[STIMULATION]
9999

100100
[BaseStimulation]
101+
; can be NI, RASPBERRY or RASP_NETWORK
101102
TYPE = NI
103+
;only used in RASP_NETWORK
104+
IP = None
105+
;Port from DAQ, PIN from Raspberry
102106
PORT = Dev1/PFI6
103107
STIM_TIME = 3.5
104108

105109
[RewardDispenser]
110+
; can be NI, RASPBERRY, RASP_NETWORK or ARDUINO
106111
TYPE = NI
112+
;only used in RASP_NETWORK
113+
IP = None
114+
;Port from DAQ, PIN from Raspberry or USB PORT for ARDUINO
107115
STIM_PORT = Dev1/PFI6
108116
REMOVAL_PORT = Dev1/PFI5
109117
STIM_TIME = 3.5
Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,22 @@
11
from gpiozero import DigitalOutputDevice
22
from gpiozero.pins.pigpio import PiGPIOFactory
3+
4+
import serial
35
BOARD_IP = '192.168.1.2'
46

57

68

7-
class DigitialPiBoardDevice:
9+
class DigitalPiBoardDevice:
810
"""
911
Digital modulated devices in combination with Raspberry Pi GPIO
1012
"""
1113

12-
def __init__(self, PIN, BOARD_IP: str = None, remote: bool = False):
14+
def __init__(self, PIN, BOARD_IP: str = None):
15+
1316
"""
1417
:param BOARD_IP: IP adress of board connected to the Device
1518
"""
16-
if remote:
19+
if BOARD_IP is not None:
1720
self._factory = PiGPIOFactory(host = BOARD_IP)
1821
self._device = DigitalOutputDevice(PIN= PIN, pin_factory = self._factory)
1922
else:
@@ -34,3 +37,28 @@ def toggle(self):
3437
self._running = self._device.is_active
3538

3639

40+
class DigitalArduinoDevice:
41+
"""
42+
Digital modulated devices in combination with Arduino boards connected via USB
43+
"""
44+
45+
def __init__(self, PORT):
46+
"""
47+
:param PORT: USB PORT of the arduino board
48+
"""
49+
self._device = serial.Serial(PORT, baudrate=19200)
50+
self._running = False
51+
52+
def turn_on(self):
53+
self._device.write(b'1')
54+
self._running = True
55+
56+
def turn_off(self):
57+
self._device.write(b'0')
58+
self._running = False
59+
60+
def toggle(self):
61+
if self._running:
62+
self.turn_off()
63+
else:
64+
self.turn_on()

utils/VideoAnalyzer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
from utils.configloader import VIDEO_SOURCE, OUT_DIR, ANIMALS_NUMBER
1919
from experiments.custom.experiments import ExampleExperiment
2020

21+
from deepposekit.models import load_model
22+
2123

2224
def create_dataframes(data_output):
2325
"""
@@ -43,7 +45,7 @@ def create_dataframes(data_output):
4345

4446
def start_videoanalyser():
4547
print("Starting DeepLabCut")
46-
config, sess, inputs, outputs = load_deeplabcut()
48+
model = load_model(r"D:\DeepPoseKit-Data-master\datasets\fly\best_model_densenet.h5")
4749

4850
experiment_enabled = False
4951
video_output = True

utils/videoanalyzer_dpp.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
"""
2+
DeepLabStream
3+
© J.Schweihoff, M. Loshakov
4+
University Bonn Medical Faculty, Germany
5+
https://github.com/SchwarzNeuroconLab/DeepLabStream
6+
Licensed under GNU General Public License v3.0
7+
"""
8+
9+
import cv2
10+
from deepposekit.models import load_model
11+
from deepposekit.io import VideoReader
12+
import numpy as np
13+
import tensorflow as tf
14+
15+
16+
17+
def plot_dlc_bodyparts(image, bodyparts):
18+
"""
19+
Plots dlc bodyparts on given image
20+
adapted from plotter
21+
"""
22+
23+
for bp in bodyparts:
24+
center = tuple(bp.astype(int))
25+
cv2.circle(image, center=center, radius=3, color=(255, 0, 0), thickness=2)
26+
return image
27+
28+
def start_videoanalyser():
29+
print("Starting DeepPoseKit")
30+
video = cv2.VideoCapture(r"D:\DeepPoseKit-Data-master\datasets\fly\video.avi")
31+
resolution = int(video.get(cv2.CAP_PROP_FRAME_WIDTH)), int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
32+
33+
path_model = r"D:\DeepPoseKit-Data-master\datasets\fly\best_model_densenet.h5"
34+
model = load_model(path_model)
35+
36+
37+
predict_model = model.predict_model
38+
# predict_model.layers.pop(0) # remove current input layer
39+
#
40+
# inputs = tf.keras.layers.Input((resolution[0], resolution[1], 3))
41+
# outputs = predict_model(inputs)
42+
# predict_model = tf.keras.Model(inputs, outputs)
43+
44+
45+
experiment_enabled = False
46+
47+
index = 0
48+
while video.isOpened():
49+
ret, frame = video.read()
50+
if ret is not None:
51+
org_frame = frame
52+
frame = frame[..., 1][..., None]
53+
st_frame = np.stack([frame])
54+
prediction = predict_model.predict(st_frame, batch_size= 1, verbose=True)
55+
x, y, confidence = np.split(prediction, 3, -1)
56+
57+
print(prediction.shape)
58+
predi = prediction[0,:,:2]
59+
pre = predi[:, :2]
60+
print(pre)
61+
out_frame = plot_dlc_bodyparts(org_frame, predi)
62+
# out_frame = org_frame
63+
cv2.imshow('stream', out_frame)
64+
index += 1
65+
else:
66+
break
67+
68+
if cv2.waitKey(1) & 0xFF == ord('q'):
69+
break
70+
71+
video.release()
72+
73+
74+
if __name__ == "__main__":
75+
start_videoanalyser()

0 commit comments

Comments
 (0)