Skip to content

Commit 18361ec

Browse files
add unit tests
1 parent 09b3b3e commit 18361ec

File tree

2 files changed

+137
-3
lines changed

2 files changed

+137
-3
lines changed

brainbox/behavior/dlc.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,8 @@ def valid_feature(x: str):
5151

5252

5353
def likelihood_threshold(dlc, threshold=0.9):
54-
"""
55-
Set dlc points with likelihood less than threshold to nan.
54+
"""Set dlc points with likelihood less than threshold to nan.
5655
57-
FIXME Add unit test.
5856
:param dlc: dlc pqt object
5957
:param threshold: likelihood threshold
6058
:return:

brainbox/tests/test_behavior.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,153 @@
33
from unittest import mock
44
from functools import partial
55
import numpy as np
6+
import pandas as pd
67
import pickle
78
import copy
89

910
from iblutil.util import Bunch
1011
from one.api import ONE
1112

13+
import brainbox.behavior.dlc as dlc
1214
import brainbox.behavior.wheel as wheel
1315
import brainbox.behavior.training as train
1416
from ibllib.tests import TEST_DB
1517

1618

19+
class TestDLC(unittest.TestCase):
20+
21+
def setUp(self):
22+
pass
23+
24+
def test_plt_window(self):
25+
"""Test for brainbox.behavior.dlc.plt_window"""
26+
window_lag = -0.5
27+
window_len = 2
28+
x = 10
29+
beg, end = dlc.plt_window(x)
30+
self.assertTrue(beg == x + window_lag, msg='Unexpected beg time')
31+
self.assertTrue(end == x + window_len, msg='Unexpected beg time')
32+
33+
def test_insert_idx(self):
34+
"""Test for brainbox.behavior.dlc.insert_idx"""
35+
36+
# Test basic functionality with simple arrays
37+
array = np.array([1, 3, 5, 7, 9])
38+
values = np.array([2, 6])
39+
result = dlc.insert_idx(array, values)
40+
expected = np.array([1, 3])
41+
np.testing.assert_array_equal(result, expected)
42+
43+
# Test when values exactly match array elements
44+
array = np.array([1, 3, 5, 7, 9])
45+
values = np.array([3, 7])
46+
result = dlc.insert_idx(array, values)
47+
expected = np.array([1, 3])
48+
np.testing.assert_array_equal(result, expected)
49+
50+
# Test values at array boundaries
51+
array = np.array([2, 4, 6, 8, 10])
52+
values = np.array([1, 11]) # Below first, above last
53+
result = dlc.insert_idx(array, values)
54+
expected = np.array([0, 4])
55+
np.testing.assert_array_equal(result, expected)
56+
57+
# Test with single value to insert
58+
array = np.array([1, 3, 5, 7, 9])
59+
values = np.array([4])
60+
result = dlc.insert_idx(array, values)
61+
expected = np.array([2])
62+
np.testing.assert_array_equal(result, expected)
63+
64+
# Test that ValueError is raised when all values map to index 0
65+
array = np.array([10, 20, 30])
66+
values = np.array([1, 2, 3]) # All much closer to first element
67+
with self.assertRaises(ValueError) as context:
68+
dlc.insert_idx(array, values)
69+
self.assertIn('Something is wrong, all values to insert are outside of the array', str(context.exception))
70+
71+
# Test with negative values in array and values
72+
array = np.array([-5, -2, 1, 4])
73+
values = np.array([-3, 0])
74+
result = dlc.insert_idx(array, values)
75+
expected = np.array([1, 2])
76+
np.testing.assert_array_equal(result, expected)
77+
78+
def test_valid_feature(self):
79+
"""Test for brainbox.behavior.dlc.valid_feature"""
80+
81+
valid = dlc.valid_feature('test_x')
82+
self.assertTrue(valid, msg='strings ending in "x" should be valid')
83+
84+
valid = dlc.valid_feature('test_y')
85+
self.assertTrue(valid, msg='strings ending in "y" should be valid')
86+
87+
valid = dlc.valid_feature('test_likelihood')
88+
self.assertTrue(valid, msg='strings ending in "likelihood" should be valid')
89+
90+
valid = dlc.valid_feature('test_l')
91+
self.assertTrue(not valid, msg='only strings ending in "_x", "_y" or "_likelihood" are valid')
92+
93+
valid = dlc.valid_feature('testlikelihood')
94+
self.assertTrue(not valid, msg='only strings ending in "_x", "_y" or "_likelihood" are valid')
95+
96+
def test_likelihood_threshold(self):
97+
"""Test for brainbox.behavior.dlc.likelihood_threshold"""
98+
99+
dlc_data = pd.DataFrame({
100+
'nose_x': [10.0, 20.0, 30.0, 40.0, 50.0],
101+
'nose_y': [15.0, 25.0, 35.0, 45.0, 55.0],
102+
'nose_likelihood': [0.95, 0.85, 0.92, 0.88, 0.91],
103+
'ear_x': [12.0, 22.0, 32.0, 42.0, 52.0],
104+
'ear_y': [17.0, 27.0, 37.0, 47.0, 57.0],
105+
'ear_likelihood': [0.99, 0.75, 0.80, 0.95, 0.60],
106+
'tail_x': [5.0, 15.0, 25.0, 35.0, 45.0],
107+
'tail_y': [8.0, 18.0, 28.0, 38.0, 48.0],
108+
'tail_likelihood': [0.70, 0.95, 0.85, 0.92, 0.88]
109+
})
110+
111+
# Test with default threshold of 0.9
112+
result = dlc.likelihood_threshold(dlc_data.copy())
113+
nans = {
114+
'nose': np.array([1, 3]),
115+
'ear': np.array([1, 2, 4]),
116+
'tail': np.array([0, 2, 4]),
117+
}
118+
for kp in nans.keys():
119+
for idx in range(5):
120+
if len(np.where(nans[kp] == idx)[0]) > 0:
121+
self.assertTrue(pd.isna(result.loc[idx, f'{kp}_x']))
122+
self.assertTrue(pd.isna(result.loc[idx, f'{kp}_y']))
123+
else:
124+
self.assertFalse(pd.isna(result.loc[idx, f'{kp}_x']))
125+
self.assertFalse(pd.isna(result.loc[idx, f'{kp}_y']))
126+
127+
# Test with custom threshold
128+
result = dlc.likelihood_threshold(dlc_data.copy(), threshold=0.8)
129+
nans = {
130+
'nose': np.array([]),
131+
'ear': np.array([1, 4]),
132+
'tail': np.array([0]),
133+
}
134+
for kp in nans.keys():
135+
for idx in range(5):
136+
if len(np.where(nans[kp] == idx)[0]) > 0:
137+
self.assertTrue(pd.isna(result.loc[idx, f'{kp}_x']))
138+
self.assertTrue(pd.isna(result.loc[idx, f'{kp}_y']))
139+
else:
140+
self.assertFalse(pd.isna(result.loc[idx, f'{kp}_x']))
141+
self.assertFalse(pd.isna(result.loc[idx, f'{kp}_y']))
142+
143+
# Test with dataframe containing no valid features
144+
invalid_data = pd.DataFrame({
145+
'random_col1': [1, 2, 3],
146+
'random_col2': [4, 5, 6]
147+
})
148+
result = dlc.likelihood_threshold(invalid_data)
149+
# Should return unchanged dataframe
150+
pd.testing.assert_frame_equal(result, invalid_data)
151+
152+
17153
class TestWheel(unittest.TestCase):
18154

19155
def setUp(self):

0 commit comments

Comments
 (0)