Skip to content

Commit f2266ac

Browse files
committed
updating model options
1 parent 295888f commit f2266ac

File tree

4 files changed

+134
-32
lines changed

4 files changed

+134
-32
lines changed

DeepLabStream.py

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919

2020
from utils.configloader import RESOLUTION, FRAMERATE, OUT_DIR, MODEL, MULTI_CAM, STACK_FRAMES, \
2121
ANIMALS_NUMBER, STREAMS, VIDEO, IPWEBCAM
22-
from utils.poser import load_deeplabcut, get_pose, find_local_peaks_new, calculate_skeletons
22+
from utils.poser import load_deeplabcut, get_pose, find_local_peaks_new, calculate_skeletons,\
23+
get_ma_pose, calculate_ma_skeletons, calculate_skeletons_dlc_live
2324
from utils.plotter import plot_bodyparts, plot_metadata_frame
2425

2526

@@ -264,34 +265,47 @@ def write_video(self, frames: dict, index: int):
264265
######################
265266
@staticmethod
266267
def get_pose_mp(input_q, output_q):
267-
from dlclive import DLCLive
268+
from utils.configloader import MODEL_ORIGIN
269+
from utils.poser import get_ma_pose
268270

269271
"""
270272
Process to be used for each camera/DLC stream of analysis
271273
Designed to be run in an infinite loop
272274
:param input_q: index and corresponding frame
273275
:param output_q: index and corresponding analysis
274276
"""
275-
config, sess, inputs, outputs = load_deeplabcut()
276-
while True:
277-
if input_q.full():
278-
index, frame = input_q.get()
279-
scmap, locref, pose = get_pose(frame, config, sess, inputs, outputs)
280-
peaks = find_local_peaks_new(scmap, locref, ANIMALS_NUMBER, config)
281-
output_q.put((index, peaks))
282-
283-
284-
dlc_live = DLCLive(DLC_LIVE)
285-
while True:
286-
if input_q.full():
287-
index, frame = input_q.get()
288-
if not dlc_live.is_initialized:
289-
peaks = dlc_live.init_inference(frame)
290-
else:
291-
peaks = dlc_live.get_pose(frame)
292277

293-
output_q.put((index, peaks))
278+
if MODEL_ORIGIN == 'DLC' or MODEL_ORIGIN == 'MADLC':
279+
config, sess, inputs, outputs = load_deeplabcut()
280+
while True:
281+
if input_q.full():
282+
index, frame = input_q.get()
283+
if MODEL_ORIGIN == 'DLC':
284+
scmap, locref, pose = get_pose(frame, config, sess, inputs, outputs)
285+
peaks = find_local_peaks_new(scmap, locref, ANIMALS_NUMBER, config)
286+
if MODEL_ORIGIN == 'MADLC':
287+
peaks = get_ma_pose(frame, config, sess, inputs, outputs)
288+
289+
output_q.put((index, peaks))
290+
291+
elif MODEL_ORIGIN == 'DLC-LIVE':
292+
from dlclive import DLCLive
293+
from utils.configloader import MODEL_PATH
294+
dlc_live = DLCLive(MODEL_PATH)
295+
while True:
296+
if input_q.full():
297+
index, frame = input_q.get()
298+
if not dlc_live.is_initialized:
299+
peaks = dlc_live.init_inference(frame)
300+
else:
301+
peaks = dlc_live.get_pose(frame)
302+
303+
output_q.put((index, peaks))
304+
elif MODEL_ORIGIN == 'DEEPPOSEKIT':
305+
print('Not here yet...')
294306

307+
else:
308+
raise ValueError(f'Model origin {MODEL_ORIGIN} not available.')
295309

296310
@staticmethod
297311
def create_mp_tools(devices):

settings.ini

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,21 +4,20 @@ FRAMERATE = 30
44
STREAMS = color, depth, infrared
55
OUTPUT_DIRECTORY = /Output
66
MULTIPLE_DEVICES = False
7-
CAMERA_SOURCE = 0
7+
CAMERA_SOURCE = 2
88

9-
[DeepLabCut]
10-
DLC_PATH = DLC_PATH
11-
MODEL = MODEL_NAME
9+
[Pose Estimation]
10+
MODEL_ORIGIN = DLC
11+
MODEL_PATH = MODEL_PATH
12+
MODEL_NAME = MODEL
13+
; only used in DLC-LIVE for now
14+
ALL_BODYPARTS = bp1, bp2, bp3, bp4
1215

1316
[Experiment]
14-
EXP_ORIGIN = CUSTOM/BASE
17+
EXP_ORIGIN = BASE/CUSTOM
1518
EXP_NAME = CONFIG_NAME
1619
RECORD_EXP = True
1720

1821
[Video]
1922
VIDEO_SOURCE = PATH_TO_PRERECORDED_VIDEO
2023
VIDEO = False
21-
22-
[IPWEBCAM]
23-
PORT = 5555
24-
IPWEBCAM = True

utils/configloader.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,21 @@ def get_script_path():
2626
# DeepLabCut
2727
deeplabcut_config = dict(dsc_config.items('DeepLabCut'))
2828

29+
#poseestimation
30+
MODEL_ORIGIN = dsc_config['Pose Estimation'].get('MODEL_ORIGIN')
31+
MODEL_PATH = dsc_config['Pose Estimation'].get('MODEL_PATH')
32+
MODEL_NAME = dsc_config['Pose Estimation'].get('MODEL_NAME')
33+
ALL_BODYPARTS = tuple(part for part in dsc_config['Streaming'].get('ALL_BODYPARTS').split(','))
34+
35+
36+
2937
# Streaming items
3038
try:
3139
RESOLUTION = tuple(int(part) for part in dsc_config['Streaming'].get('RESOLUTION').split(','))
3240
except ValueError:
3341
print('Incorrect resolution in config!\n'
3442
'Using default value "RESOLUTION = 848, 480"')
3543
RESOLUTION = (848, 480)
36-
MODEL = dsc_config['Streaming'].get('MODEL')
3744
FRAMERATE = dsc_config['Streaming'].getint('FRAMERATE')
3845
OUT_DIR = dsc_config['Streaming'].get('OUTPUT_DIRECTORY')
3946
STREAM = dsc_config['Streaming'].getboolean('STREAM')

utils/poser.py

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from scipy.ndimage.measurements import label, maximum_position
1717
from scipy.ndimage.morphology import generate_binary_structure, binary_erosion
1818
from scipy.ndimage.filters import maximum_filter
19-
from utils.configloader import deeplabcut_config
19+
from utils.configloader import deeplabcut_config, MODEL_ORIGIN
2020

2121
MODEL = deeplabcut_config['model']
2222
DLC_PATH = deeplabcut_config['dlc_path']
@@ -25,6 +25,8 @@
2525
try:
2626
import deeplabcut.pose_estimation_tensorflow.nnet.predict as predict
2727
from deeplabcut.pose_estimation_tensorflow.config import load_config
28+
from deeplabcut.pose_estimation_tensorflow.nnet import predict_multianimal
29+
2830
models_folder = 'pose_estimation_tensorflow/models/'
2931
# if not DLC 2 is not installed, try import from DLC 1 the old way
3032
except ImportError:
@@ -118,7 +120,7 @@ def find_local_peaks_new(scoremap: np.ndarray, local_reference: np.ndarray, anim
118120
return all_peaks
119121

120122

121-
def calculate_skeletons(peaks: dict, animals_number: int) -> list:
123+
def calculate_dlstream_skeletons(peaks: dict, animals_number: int) -> list:
122124
"""
123125
Creating skeletons from given peaks
124126
There could be no more skeletons than animals_number
@@ -188,6 +190,61 @@ def create_animal_skeleton(dots_cluster: tuple) -> dict:
188190

189191
return animal_skeletons
190192

193+
"""maDLC"""
194+
195+
def get_ma_pose(image, config, session, inputs, outputs):
196+
"""
197+
Gets scoremap, local reference and pose from DeepLabCut using given image
198+
Pose is most probable points for each joint, and not really used later
199+
Scoremap and local reference is essential to extract skeletons
200+
:param image: frame which would be analyzed
201+
:param config, session, inputs, outputs: DeepLabCut configuration and TensorFlow variables from load_deeplabcut()
202+
203+
:return: tuple of scoremap, local reference and pose
204+
"""
205+
scmap, locref, paf, pose = predict_multianimal.get_detectionswithcosts(image, config, session, inputs, outputs, outall=True,
206+
nms_radius=5.0,
207+
det_min_score=0.1,
208+
c_engine=False)
209+
return pose
210+
211+
def calculate_ma_skeletons(pose: dict, animals_number: int) -> list:
212+
"""
213+
Creating skeletons from given pose in maDLC
214+
There could be no more skeletons than animals_number
215+
Only unique skeletons output
216+
"""
217+
218+
def extract_to_animal_skeleton(coords):
219+
"""
220+
Creating a easy to read skeleton from dots cluster
221+
Format for each joint:
222+
{'joint_name': (x,y)}
223+
"""
224+
bodyparts = np.array(coords[0])
225+
skeletons = {}
226+
for bp in range(len(bodyparts)):
227+
for animal_num in range(animals_number):
228+
if 'Mouse'+str(animal_num+1) not in skeletons.keys():
229+
skeletons['Mouse' + str(animal_num + 1)] = {}
230+
if len(bodyparts[bp]) >= animals_number:
231+
skeletons['Mouse'+str(animal_num+1)]['bp' + str(bp + 1)] = bodyparts[bp][animal_num].astype(int)
232+
else:
233+
if animal_num < len(bodyparts[bp]):
234+
skeletons['Mouse'+str(animal_num+1)]['bp' + str(bp + 1)] = bodyparts[bp][animal_num].astype(int)
235+
else:
236+
skeletons['Mouse'+str(animal_num+1)]['bp' + str(bp + 1)] = np.array([0,0])
237+
238+
return skeletons
239+
animal_skeletons = extract_to_animal_skeleton(pose['coordinates'])
240+
# animal_skeletons = list(animal_skeletons.values())
241+
242+
return animal_skeletons
243+
244+
245+
"""DLC LIVE"""
246+
247+
191248

192249
def transform_2skeleton(pose):
193250
from utils.configloader import ALL_BODYPARTS
@@ -215,3 +272,28 @@ def calculate_skeletons_dlc_live(pose ,animals_number: int = 1) -> list:
215272
skeletons = [transform_2skeleton(pose)]
216273

217274
return skeletons
275+
276+
277+
def calculate_skeletons(peaks: dict, animals_number: int) -> list:
278+
"""
279+
Creating skeletons from given peaks
280+
There could be no more skeletons than animals_number
281+
Only unique skeletons output
282+
adaptive to chosen model origin
283+
"""
284+
285+
if MODEL_ORIGIN == 'DLC':
286+
animal_skeletons = calculate_dlstream_skeletons(peaks, animals_number)
287+
288+
elif MODEL_ORIGIN == 'MADLC':
289+
animal_skeletons = calculate_ma_skeletons(peaks, animals_number)
290+
291+
elif MODEL_ORIGIN == 'DLC-LIVE':
292+
animal_skeletons = calculate_skeletons_dlc_live(peaks, animals_number= 1)
293+
if animals_number != 1:
294+
raise ValueError('Multiple animals are currently not supported by DLC-LIVE.'
295+
' If you are using differently colored animals, please refere to the bodyparts directly.')
296+
297+
298+
299+

0 commit comments

Comments
 (0)