diff --git a/src/hdmf/common/table.py b/src/hdmf/common/table.py index f4b2cd011..9d0072156 100644 --- a/src/hdmf/common/table.py +++ b/src/hdmf/common/table.py @@ -199,7 +199,7 @@ def get(self, arg, **kwargs): if isinstance(arg, slice): indices = list(range(*arg.indices(len(self.data)))) else: - if isinstance(arg[0], bool): + if isinstance(arg[0], (bool, np.bool_)): arg = np.where(arg)[0] indices = arg ret = list() @@ -1108,17 +1108,9 @@ def __get_selection_as_dict(self, arg, df, index, exclude=None, **kwargs): return ret # if index is out of range, different errors can be generated depending on the dtype of the column # but despite the differences, raise an IndexError from that error - except ValueError as ve: + except IndexError as ie: # in h5py <2, if the column is an h5py.Dataset, a ValueError was raised # in h5py 3+, this became an IndexError - x = re.match(r"^Index \((.*)\) out of range \(.*\)$", str(ve)) - if x: - msg = ("Row index %s out of range for %s '%s' (length %d)." - % (x.groups()[0], self.__class__.__name__, self.name, len(self))) - raise IndexError(msg) from ve - else: # pragma: no cover - raise ve - except IndexError as ie: x = re.match(r"^Index \((.*)\) out of range for \(.*\)$", str(ie)) if x: msg = ("Row index %s out of range for %s '%s' (length %d)." @@ -1288,10 +1280,8 @@ def generate_html_repr(self, level: int = 0, access_code: str = "", nrows: int = inside = f"{self[:min(nrows, len(self))].to_html()}" - if len(self) == nrows + 1: - inside += "
... and 1 more row.
" - elif len(self) > nrows + 1: - inside += f"... and {len(self) - nrows} more rows.
" + if len(self) >= nrows + 1: + inside += f"... and {len(self) - nrows} more row(s).
" out += ( f'\n ' ' | foo | \nbar | \nbaz | \n
---|---|---|---|
id | \n ' - '\n | \n | \n |
... and 1 more row(s).
' ) @@ -1114,6 +1195,14 @@ def test_eq_bad_type(self): table = self.with_columns_and_data() self.assertFalse(table == container) + def test_copy(self): + table = self.with_columns_and_data() + table2 = table.copy() + self.assertTrue(table == table2) + self.assertIsNot(table, table2) + for colname in table.colnames: + self.assertTrue(getattr(table, colname) == getattr(table2, colname)) + class TestDynamicTableRoundTrip(H5RoundTripMixin, TestCase): @@ -1285,6 +1374,34 @@ def test_no_df_nested(self): with self.assertRaisesWith(ValueError, msg): dynamic_table_region.get(0, df=False, index=False) + def test_create_region_with_valid_slice_range(self): + table = self.with_columns_and_data() + region = table.create_region(name='region', region=slice(0, 2), description='test region') + self.assertEqual(region.data, [0, 1]) + + def test_create_region_with_invalid_slice_range(self): + table = self.with_columns_and_data() + msg = 'region slice slice(-1, 2, None) is out of range for this DynamicTable of length 5' + with self.assertRaisesWith(IndexError, msg): + table.create_region(name='region2', region=slice(-1, 2), description='test region') + + def test_create_region_with_none_slice(self): + table = self.with_columns_and_data() + region = table.create_region(name='region2', region=slice(0, None), description='test region') + self.assertEqual(region.data, [0, 1, 2, 3, 4]) + + def test_create_region_with_negative_index(self): + table = self.with_columns_and_data() + + msg = 'The index -1 is out of range for this DynamicTable of length 5' + with self.assertRaisesWith(IndexError, msg): + table.create_region(name='region', region=[-1, 0], description='test region') + + def test_create_region_with_out_of_range_index(self): + table = self.with_columns_and_data() + msg = 'The index 10 is out of range for this DynamicTable of length 5' + with self.assertRaisesWith(IndexError, msg): + table.create_region(name='region', region=[0, 10], description='test region') class DynamicTableRegionRoundTrip(H5RoundTripMixin, TestCase): @@ -2463,6 +2580,23 @@ def test_init_data(self): self.assertListEqual(foo_ind[0], ['a', 'b']) self.assertListEqual(foo_ind[1], ['c']) + def test_get_with_boolean(self): + """Test VectorIndex.get with boolean argument""" + data = VectorData(name='data', description='desc', data=['a', 'b', 'c', 'd', 'e']) + index = VectorIndex(name='index', data=[2, 3, 5], target=data) + result = index.get([True, False, True]) + + self.assertEqual(result, [['a', 'b',], ['d', 'e']]) + self.assertEqual(len(result), 2) + + def test_get_with_boolean_array(self): + """Test VectorIndex.get with boolean np.array argument""" + data = VectorData(name='data', description='desc', data=['a', 'b', 'c', 'd', 'e']) + index = VectorIndex(name='index', data=[2, 3, 5], target=data) + result = index.get(np.array([True, False, True])) + + self.assertEqual(result, [['a', 'b',], ['d', 'e']]) + self.assertEqual(len(result), 2) class TestDoubleIndex(TestCase): @@ -2610,6 +2744,14 @@ def test_enum_index(self): index=pd.Series(name='id', data=[0, 1, 2])) pd.testing.assert_frame_equal(exp, rec) + def test_add_column_table_and_enum_error(self): + """Test that adding a column with both table and enum raises an error.""" + table = DynamicTable(name='table0', description='an example table') + + msg = "column 'col1' cannot be both a table region and come from an enumerable set of elements" + with self.assertRaisesWith(ValueError, msg): + table.add_column(name='col1', description='test', table=True, enum=True) + class TestDynamicTableInitIndexRoundTrip(H5RoundTripMixin, TestCase):