|
3 | 3 | from unittest import mock |
4 | 4 | from functools import partial |
5 | 5 | import numpy as np |
| 6 | +import pandas as pd |
6 | 7 | import pickle |
7 | 8 | import copy |
8 | 9 |
|
9 | 10 | from iblutil.util import Bunch |
10 | 11 | from one.api import ONE |
11 | 12 |
|
| 13 | +import brainbox.behavior.dlc as dlc |
12 | 14 | import brainbox.behavior.wheel as wheel |
13 | 15 | import brainbox.behavior.training as train |
14 | 16 | from ibllib.tests import TEST_DB |
15 | 17 |
|
16 | 18 |
|
| 19 | +class TestDLC(unittest.TestCase): |
| 20 | + |
| 21 | + def setUp(self): |
| 22 | + pass |
| 23 | + |
| 24 | + def test_plt_window(self): |
| 25 | + """Test for brainbox.behavior.dlc.plt_window""" |
| 26 | + window_lag = -0.5 |
| 27 | + window_len = 2 |
| 28 | + x = 10 |
| 29 | + beg, end = dlc.plt_window(x) |
| 30 | + self.assertTrue(beg == x + window_lag, msg='Unexpected beg time') |
| 31 | + self.assertTrue(end == x + window_len, msg='Unexpected beg time') |
| 32 | + |
| 33 | + def test_insert_idx(self): |
| 34 | + """Test for brainbox.behavior.dlc.insert_idx""" |
| 35 | + |
| 36 | + # Test basic functionality with simple arrays |
| 37 | + array = np.array([1, 3, 5, 7, 9]) |
| 38 | + values = np.array([2, 6]) |
| 39 | + result = dlc.insert_idx(array, values) |
| 40 | + expected = np.array([1, 3]) |
| 41 | + np.testing.assert_array_equal(result, expected) |
| 42 | + |
| 43 | + # Test when values exactly match array elements |
| 44 | + array = np.array([1, 3, 5, 7, 9]) |
| 45 | + values = np.array([3, 7]) |
| 46 | + result = dlc.insert_idx(array, values) |
| 47 | + expected = np.array([1, 3]) |
| 48 | + np.testing.assert_array_equal(result, expected) |
| 49 | + |
| 50 | + # Test values at array boundaries |
| 51 | + array = np.array([2, 4, 6, 8, 10]) |
| 52 | + values = np.array([1, 11]) # Below first, above last |
| 53 | + result = dlc.insert_idx(array, values) |
| 54 | + expected = np.array([0, 4]) |
| 55 | + np.testing.assert_array_equal(result, expected) |
| 56 | + |
| 57 | + # Test with single value to insert |
| 58 | + array = np.array([1, 3, 5, 7, 9]) |
| 59 | + values = np.array([4]) |
| 60 | + result = dlc.insert_idx(array, values) |
| 61 | + expected = np.array([2]) |
| 62 | + np.testing.assert_array_equal(result, expected) |
| 63 | + |
| 64 | + # Test that ValueError is raised when all values map to index 0 |
| 65 | + array = np.array([10, 20, 30]) |
| 66 | + values = np.array([1, 2, 3]) # All much closer to first element |
| 67 | + with self.assertRaises(ValueError) as context: |
| 68 | + dlc.insert_idx(array, values) |
| 69 | + self.assertIn('Something is wrong, all values to insert are outside of the array', str(context.exception)) |
| 70 | + |
| 71 | + # Test with negative values in array and values |
| 72 | + array = np.array([-5, -2, 1, 4]) |
| 73 | + values = np.array([-3, 0]) |
| 74 | + result = dlc.insert_idx(array, values) |
| 75 | + expected = np.array([1, 2]) |
| 76 | + np.testing.assert_array_equal(result, expected) |
| 77 | + |
| 78 | + def test_valid_feature(self): |
| 79 | + """Test for brainbox.behavior.dlc.valid_feature""" |
| 80 | + |
| 81 | + valid = dlc.valid_feature('test_x') |
| 82 | + self.assertTrue(valid, msg='strings ending in "x" should be valid') |
| 83 | + |
| 84 | + valid = dlc.valid_feature('test_y') |
| 85 | + self.assertTrue(valid, msg='strings ending in "y" should be valid') |
| 86 | + |
| 87 | + valid = dlc.valid_feature('test_likelihood') |
| 88 | + self.assertTrue(valid, msg='strings ending in "likelihood" should be valid') |
| 89 | + |
| 90 | + valid = dlc.valid_feature('test_l') |
| 91 | + self.assertTrue(not valid, msg='only strings ending in "_x", "_y" or "_likelihood" are valid') |
| 92 | + |
| 93 | + valid = dlc.valid_feature('testlikelihood') |
| 94 | + self.assertTrue(not valid, msg='only strings ending in "_x", "_y" or "_likelihood" are valid') |
| 95 | + |
| 96 | + def test_likelihood_threshold(self): |
| 97 | + """Test for brainbox.behavior.dlc.likelihood_threshold""" |
| 98 | + |
| 99 | + dlc_data = pd.DataFrame({ |
| 100 | + 'nose_x': [10.0, 20.0, 30.0, 40.0, 50.0], |
| 101 | + 'nose_y': [15.0, 25.0, 35.0, 45.0, 55.0], |
| 102 | + 'nose_likelihood': [0.95, 0.85, 0.92, 0.88, 0.91], |
| 103 | + 'ear_x': [12.0, 22.0, 32.0, 42.0, 52.0], |
| 104 | + 'ear_y': [17.0, 27.0, 37.0, 47.0, 57.0], |
| 105 | + 'ear_likelihood': [0.99, 0.75, 0.80, 0.95, 0.60], |
| 106 | + 'tail_x': [5.0, 15.0, 25.0, 35.0, 45.0], |
| 107 | + 'tail_y': [8.0, 18.0, 28.0, 38.0, 48.0], |
| 108 | + 'tail_likelihood': [0.70, 0.95, 0.85, 0.92, 0.88] |
| 109 | + }) |
| 110 | + |
| 111 | + # Test with default threshold of 0.9 |
| 112 | + result = dlc.likelihood_threshold(dlc_data.copy()) |
| 113 | + nans = { |
| 114 | + 'nose': np.array([1, 3]), |
| 115 | + 'ear': np.array([1, 2, 4]), |
| 116 | + 'tail': np.array([0, 2, 4]), |
| 117 | + } |
| 118 | + for kp in nans.keys(): |
| 119 | + for idx in range(5): |
| 120 | + if len(np.where(nans[kp] == idx)[0]) > 0: |
| 121 | + self.assertTrue(pd.isna(result.loc[idx, f'{kp}_x'])) |
| 122 | + self.assertTrue(pd.isna(result.loc[idx, f'{kp}_y'])) |
| 123 | + else: |
| 124 | + self.assertFalse(pd.isna(result.loc[idx, f'{kp}_x'])) |
| 125 | + self.assertFalse(pd.isna(result.loc[idx, f'{kp}_y'])) |
| 126 | + |
| 127 | + # Test with custom threshold |
| 128 | + result = dlc.likelihood_threshold(dlc_data.copy(), threshold=0.8) |
| 129 | + nans = { |
| 130 | + 'nose': np.array([]), |
| 131 | + 'ear': np.array([1, 4]), |
| 132 | + 'tail': np.array([0]), |
| 133 | + } |
| 134 | + for kp in nans.keys(): |
| 135 | + for idx in range(5): |
| 136 | + if len(np.where(nans[kp] == idx)[0]) > 0: |
| 137 | + self.assertTrue(pd.isna(result.loc[idx, f'{kp}_x'])) |
| 138 | + self.assertTrue(pd.isna(result.loc[idx, f'{kp}_y'])) |
| 139 | + else: |
| 140 | + self.assertFalse(pd.isna(result.loc[idx, f'{kp}_x'])) |
| 141 | + self.assertFalse(pd.isna(result.loc[idx, f'{kp}_y'])) |
| 142 | + |
| 143 | + # Test with dataframe containing no valid features |
| 144 | + invalid_data = pd.DataFrame({ |
| 145 | + 'random_col1': [1, 2, 3], |
| 146 | + 'random_col2': [4, 5, 6] |
| 147 | + }) |
| 148 | + result = dlc.likelihood_threshold(invalid_data) |
| 149 | + # Should return unchanged dataframe |
| 150 | + pd.testing.assert_frame_equal(result, invalid_data) |
| 151 | + |
| 152 | + |
17 | 153 | class TestWheel(unittest.TestCase): |
18 | 154 |
|
19 | 155 | def setUp(self): |
|
0 commit comments