diff --git a/src/nested_pandas/nestedframe/core.py b/src/nested_pandas/nestedframe/core.py index d35d0ce1..42e523a5 100644 --- a/src/nested_pandas/nestedframe/core.py +++ b/src/nested_pandas/nestedframe/core.py @@ -255,19 +255,18 @@ def _getitem_str(self, item): else: raise KeyError(f"Column '{cleaned_item}' not found in nested columns or base columns") - def _is_key_list(self, item): + @staticmethod + def _is_key_list(item): if not is_list_like(item): return False if is_bool_dtype(item): return False - for k in item: - if not isinstance(k, str): - return False - if not self._is_known_column(k): - return False - return True + return all(isinstance(k, str) for k in item) def _getitem_list(self, item): + unknown_cols = [k for k in item if not self._is_known_column(k)] + if unknown_cols: + raise KeyError(f"{unknown_cols} not in index") non_nested_keys = [k for k in item if k in self.columns] result = super().__getitem__(non_nested_keys) components = [self._parse_hierarchical_components(k) for k in item] diff --git a/tests/nested_pandas/nestedframe/test_nestedframe.py b/tests/nested_pandas/nestedframe/test_nestedframe.py index 3ddbbcf6..77314ab0 100644 --- a/tests/nested_pandas/nestedframe/test_nestedframe.py +++ b/tests/nested_pandas/nestedframe/test_nestedframe.py @@ -1,3 +1,5 @@ +import re + import numpy as np import pandas as pd import pyarrow as pa @@ -246,14 +248,15 @@ def test_get_nested_columns_errors(): base = base.join_nested(nested, "nested") - with pytest.raises(KeyError): - base[["a", "c"]] - - with pytest.raises(KeyError): - base[["a", "nested.g"]] - - with pytest.raises(KeyError): - base[["a", "nested.a", "wrong.b"]] + # Escaping the list of columns for a strict check + with pytest.raises(KeyError, match=re.escape("['c']")): + _ = base[["a", "c"]] + with pytest.raises(KeyError, match=re.escape("['nested.g']")): + _ = base[["a", "nested.c", "nested.g"]] + with pytest.raises(KeyError, match=re.escape("['wrong.b']")): + _ = base[["a", "nested.c", "wrong.b"]] + with pytest.raises(KeyError, match=re.escape("['c', 'wrong.b']")): + _ = base[["c", "nested.c", "wrong.b"]] def test_getitem_empty_bool_array():