Skip to content

Commit 3c00184

Browse files
authored
Merge pull request #1003 from int-brain-lab/lpqc
PostLP task
2 parents f7afd5d + 18361ec commit 3c00184

File tree

10 files changed

+546
-132
lines changed

10 files changed

+546
-132
lines changed

brainbox/behavior/dlc.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,20 @@ def insert_idx(array, values):
4444
return idx
4545

4646

47+
def valid_feature(x: str):
48+
if x.endswith('_x') or x.endswith('_y') or x.endswith('_likelihood'):
49+
return True
50+
return False
51+
52+
4753
def likelihood_threshold(dlc, threshold=0.9):
48-
"""
49-
Set dlc points with likelihood less than threshold to nan.
54+
"""Set dlc points with likelihood less than threshold to nan.
5055
51-
FIXME Add unit test.
5256
:param dlc: dlc pqt object
5357
:param threshold: likelihood threshold
5458
:return:
5559
"""
56-
features = np.unique(['_'.join(x.split('_')[:-1]) for x in dlc.keys()])
60+
features = np.unique(['_'.join(x.split('_')[:-1]) for x in dlc.keys() if valid_feature(x)])
5761
for feat in features:
5862
nan_fill = dlc[f'{feat}_likelihood'] < threshold
5963
dlc.loc[nan_fill, (f'{feat}_x', f'{feat}_y')] = np.nan
@@ -268,7 +272,7 @@ def plot_trace_on_frame(frame, dlc_df, cam):
268272
# Threshold the dlc traces
269273
dlc_df = likelihood_threshold(dlc_df)
270274
# Features without tube
271-
features = np.unique(['_'.join(x.split('_')[:-1]) for x in dlc_df.keys() if 'tube' not in x])
275+
features = np.unique(['_'.join(x.split('_')[:-1]) for x in dlc_df.keys() if valid_feature(x) and 'tube' not in x])
272276
# Normalize the number of points across cameras
273277
dlc_df_norm = pd.DataFrame()
274278
for feat in features:

brainbox/io/one.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1551,7 +1551,7 @@ def load_wheel(self, fs=1000, corner_frequency=20, order=8, collection=None):
15511551
self.wheel = self.wheel.apply(np.float32)
15521552
self.data_info.loc[self.data_info['name'] == 'wheel', 'is_loaded'] = True
15531553

1554-
def load_pose(self, likelihood_thr=0.9, views=['left', 'right', 'body']):
1554+
def load_pose(self, likelihood_thr=0.9, views=['left', 'right', 'body'], tracker='dlc'):
15551555
"""
15561556
Function to load the pose estimation results (DLC) into SessionLoader.pose. SessionLoader.pose is a
15571557
dictionary where keys are the names of the cameras for which pose data is loaded, and values are pandas
@@ -1565,13 +1565,17 @@ def load_pose(self, likelihood_thr=0.9, views=['left', 'right', 'body']):
15651565
likelihood_thr=1. Default is 0.9
15661566
views: list
15671567
List of camera views for which to try and load data. Possible options are {'left', 'right', 'body'}
1568+
tracker : str
1569+
Tracking algorithm to load pose estimates from. Possible options are {'dlc', 'lightningPose'}
15681570
"""
15691571
# empty the dictionary so that if one loads only one view, after having loaded several, the others don't linger
1572+
tracker = 'lightningPose' if tracker in ['lp', 'litpose'] else tracker
15701573
self.pose = {}
15711574
for view in views:
1572-
pose_raw = self.one.load_object(self.eid, f'{view}Camera', attribute=['dlc', 'times'], revision=self.revision or None)
1575+
pose_raw = self.one.load_object(
1576+
self.eid, f'{view}Camera', attribute=[tracker, 'times'], revision=self.revision or None)
15731577
# Double check if video timestamps are correct length or can be fixed
1574-
times_fixed, dlc = self._check_video_timestamps(view, pose_raw['times'], pose_raw['dlc'])
1578+
times_fixed, dlc = self._check_video_timestamps(view, pose_raw['times'], pose_raw[tracker])
15751579
self.pose[f'{view}Camera'] = likelihood_threshold(dlc, likelihood_thr)
15761580
self.pose[f'{view}Camera'].insert(0, 'times', times_fixed)
15771581
self.data_info.loc[self.data_info['name'] == 'pose', 'is_loaded'] = True

brainbox/tests/test_behavior.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,153 @@
33
from unittest import mock
44
from functools import partial
55
import numpy as np
6+
import pandas as pd
67
import pickle
78
import copy
89

910
from iblutil.util import Bunch
1011
from one.api import ONE
1112

13+
import brainbox.behavior.dlc as dlc
1214
import brainbox.behavior.wheel as wheel
1315
import brainbox.behavior.training as train
1416
from ibllib.tests import TEST_DB
1517

1618

19+
class TestDLC(unittest.TestCase):
20+
21+
def setUp(self):
22+
pass
23+
24+
def test_plt_window(self):
25+
"""Test for brainbox.behavior.dlc.plt_window"""
26+
window_lag = -0.5
27+
window_len = 2
28+
x = 10
29+
beg, end = dlc.plt_window(x)
30+
self.assertTrue(beg == x + window_lag, msg='Unexpected beg time')
31+
self.assertTrue(end == x + window_len, msg='Unexpected beg time')
32+
33+
def test_insert_idx(self):
34+
"""Test for brainbox.behavior.dlc.insert_idx"""
35+
36+
# Test basic functionality with simple arrays
37+
array = np.array([1, 3, 5, 7, 9])
38+
values = np.array([2, 6])
39+
result = dlc.insert_idx(array, values)
40+
expected = np.array([1, 3])
41+
np.testing.assert_array_equal(result, expected)
42+
43+
# Test when values exactly match array elements
44+
array = np.array([1, 3, 5, 7, 9])
45+
values = np.array([3, 7])
46+
result = dlc.insert_idx(array, values)
47+
expected = np.array([1, 3])
48+
np.testing.assert_array_equal(result, expected)
49+
50+
# Test values at array boundaries
51+
array = np.array([2, 4, 6, 8, 10])
52+
values = np.array([1, 11]) # Below first, above last
53+
result = dlc.insert_idx(array, values)
54+
expected = np.array([0, 4])
55+
np.testing.assert_array_equal(result, expected)
56+
57+
# Test with single value to insert
58+
array = np.array([1, 3, 5, 7, 9])
59+
values = np.array([4])
60+
result = dlc.insert_idx(array, values)
61+
expected = np.array([2])
62+
np.testing.assert_array_equal(result, expected)
63+
64+
# Test that ValueError is raised when all values map to index 0
65+
array = np.array([10, 20, 30])
66+
values = np.array([1, 2, 3]) # All much closer to first element
67+
with self.assertRaises(ValueError) as context:
68+
dlc.insert_idx(array, values)
69+
self.assertIn('Something is wrong, all values to insert are outside of the array', str(context.exception))
70+
71+
# Test with negative values in array and values
72+
array = np.array([-5, -2, 1, 4])
73+
values = np.array([-3, 0])
74+
result = dlc.insert_idx(array, values)
75+
expected = np.array([1, 2])
76+
np.testing.assert_array_equal(result, expected)
77+
78+
def test_valid_feature(self):
79+
"""Test for brainbox.behavior.dlc.valid_feature"""
80+
81+
valid = dlc.valid_feature('test_x')
82+
self.assertTrue(valid, msg='strings ending in "x" should be valid')
83+
84+
valid = dlc.valid_feature('test_y')
85+
self.assertTrue(valid, msg='strings ending in "y" should be valid')
86+
87+
valid = dlc.valid_feature('test_likelihood')
88+
self.assertTrue(valid, msg='strings ending in "likelihood" should be valid')
89+
90+
valid = dlc.valid_feature('test_l')
91+
self.assertTrue(not valid, msg='only strings ending in "_x", "_y" or "_likelihood" are valid')
92+
93+
valid = dlc.valid_feature('testlikelihood')
94+
self.assertTrue(not valid, msg='only strings ending in "_x", "_y" or "_likelihood" are valid')
95+
96+
def test_likelihood_threshold(self):
97+
"""Test for brainbox.behavior.dlc.likelihood_threshold"""
98+
99+
dlc_data = pd.DataFrame({
100+
'nose_x': [10.0, 20.0, 30.0, 40.0, 50.0],
101+
'nose_y': [15.0, 25.0, 35.0, 45.0, 55.0],
102+
'nose_likelihood': [0.95, 0.85, 0.92, 0.88, 0.91],
103+
'ear_x': [12.0, 22.0, 32.0, 42.0, 52.0],
104+
'ear_y': [17.0, 27.0, 37.0, 47.0, 57.0],
105+
'ear_likelihood': [0.99, 0.75, 0.80, 0.95, 0.60],
106+
'tail_x': [5.0, 15.0, 25.0, 35.0, 45.0],
107+
'tail_y': [8.0, 18.0, 28.0, 38.0, 48.0],
108+
'tail_likelihood': [0.70, 0.95, 0.85, 0.92, 0.88]
109+
})
110+
111+
# Test with default threshold of 0.9
112+
result = dlc.likelihood_threshold(dlc_data.copy())
113+
nans = {
114+
'nose': np.array([1, 3]),
115+
'ear': np.array([1, 2, 4]),
116+
'tail': np.array([0, 2, 4]),
117+
}
118+
for kp in nans.keys():
119+
for idx in range(5):
120+
if len(np.where(nans[kp] == idx)[0]) > 0:
121+
self.assertTrue(pd.isna(result.loc[idx, f'{kp}_x']))
122+
self.assertTrue(pd.isna(result.loc[idx, f'{kp}_y']))
123+
else:
124+
self.assertFalse(pd.isna(result.loc[idx, f'{kp}_x']))
125+
self.assertFalse(pd.isna(result.loc[idx, f'{kp}_y']))
126+
127+
# Test with custom threshold
128+
result = dlc.likelihood_threshold(dlc_data.copy(), threshold=0.8)
129+
nans = {
130+
'nose': np.array([]),
131+
'ear': np.array([1, 4]),
132+
'tail': np.array([0]),
133+
}
134+
for kp in nans.keys():
135+
for idx in range(5):
136+
if len(np.where(nans[kp] == idx)[0]) > 0:
137+
self.assertTrue(pd.isna(result.loc[idx, f'{kp}_x']))
138+
self.assertTrue(pd.isna(result.loc[idx, f'{kp}_y']))
139+
else:
140+
self.assertFalse(pd.isna(result.loc[idx, f'{kp}_x']))
141+
self.assertFalse(pd.isna(result.loc[idx, f'{kp}_y']))
142+
143+
# Test with dataframe containing no valid features
144+
invalid_data = pd.DataFrame({
145+
'random_col1': [1, 2, 3],
146+
'random_col2': [4, 5, 6]
147+
})
148+
result = dlc.likelihood_threshold(invalid_data)
149+
# Should return unchanged dataframe
150+
pd.testing.assert_frame_equal(result, invalid_data)
151+
152+
17153
class TestWheel(unittest.TestCase):
18154

19155
def setUp(self):

ibllib/oneibl/patcher.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -682,7 +682,7 @@ def patch_dataset(self, file_list, dry=False, ftp=False, force=False, **kwargs):
682682
_logger.error(f'Files: {", ".join([f.name for f in file_list])} already exist, to overwrite set force=True')
683683
return
684684

685-
response = super().patch_dataset(file_list, dry=dry, repository=self.s3_repo, ftp=False, **kwargs)
685+
response = super().patch_dataset(file_list, dry=dry, repository=self.s3_repo, ftp=False, force=force, **kwargs)
686686
# TODO in an ideal case the flatiron filerecord won't be altered when we register this dataset. This requires
687687
# changing the the alyx.data.register_view
688688
for ds in response:

0 commit comments

Comments
 (0)