Skip to content

Commit b667687

Browse files
committed
added padding to handle missing instances
added functionality to remove body parts not used in SIMBA
1 parent db6e777 commit b667687

File tree

4 files changed

+10
-5
lines changed

4 files changed

+10
-5
lines changed

experiments/custom/experiments.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -422,10 +422,10 @@ def _trials(self):
422422
Defining the trials
423423
"""
424424
identification_dict = dict(active={'animal': 1
425-
, 'bp': ['bp0']
425+
, 'bp': ['bp2']
426426
}
427427
,passive = {'animal': 0
428-
, 'bp': ['bp2']
428+
, 'bp': ['bp6']
429429
}
430430
)
431431

experiments/custom/stimulus_process.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@ def example_protocol_run(condition_q: mp.Queue):
7474
if condition_q.full():
7575
current_trial = condition_q.get()
7676
if current_trial is not None:
77-
print('IM HEEEEERE!')
7877
show_visual_stim_img(type=current_trial, name='DlStream')
7978
#dmod_device.toggle()
8079
else:

experiments/custom/triggers.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -589,7 +589,13 @@ def __init__(self,prob_threshold: float, class_process_pool, debug: bool = False
589589

590590
def fill_time_window(self,skeleton: dict):
591591
"""Transforms skeleton input into flat numpy array of coordinates to pass to feature extraction"""
592-
flat_values = transform_2pose(skeleton).flatten()
592+
#todo: remove bodyparts that are not used automatically
593+
key_selection = {'0_tail_tip','1_tail_tip'}
594+
skeleton_selection = {k: skeleton[k] for k in skeleton.keys() if k not in key_selection}
595+
flat_values = transform_2pose(skeleton_selection).flatten()
596+
# if not enough animals are present, padd the rest with default value "0,0"
597+
if flat_values.shape[0] < 28:
598+
flat_values = np.pad(flat_values, (0, 28-flat_values.shape[0]), 'constant', constant_values=0)
593599
# this appends the new row to the deque time_window, which will drop the "oldest" entry due to a maximum
594600
# length of time_window_len
595601
self._time_window.append(flat_values)

utils/configloader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def get_script_path():
3535
model_path_string = [str(part).strip() for part in dsc_config['Pose Estimation'].get('MODEL_PATH').split(',')]
3636
MODEL_PATH = model_path_string[0] if len(model_path_string) <= 1 else model_path_string
3737
MODEL_NAME = dsc_config['Pose Estimation'].get('MODEL_NAME')
38-
ALL_BODYPARTS = tuple(part for part in dsc_config['Pose Estimation'].get('ALL_BODYPARTS').split(','))
38+
ALL_BODYPARTS = tuple(part.strip() for part in dsc_config['Pose Estimation'].get('ALL_BODYPARTS').split(','))
3939

4040
# Streaming items
4141

0 commit comments

Comments
 (0)