Skip to content

Commit d9861d8

Browse files
Allow dataset row indexing with np.int types (#7423) (#7438)
1 parent 7af7ace commit d9861d8

File tree

2 files changed: 12 additions & 1 deletion

2 files changed: 12 additions & 1 deletion

src/datasets/formatting/formatting.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import numbers
1516
import operator
1617
from collections.abc import Iterable, Mapping, MutableMapping
1718
from functools import partial
@@ -565,7 +566,7 @@ def _check_valid_index_key(key: Union[int, slice, range, Iterable], size: int) -
565566

566567

567568
def key_to_query_type(key: Union[int, slice, range, str, Iterable]) -> str:
568-
if isinstance(key, int):
569+
if isinstance(key, numbers.Integral):
569570
return "row"
570571
elif isinstance(key, str):
571572
return "column"

tests/test_arrow_dataset.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4605,12 +4605,22 @@ async def f(batch):
46054605
assert len(out) == 1
46064606

46074607

4608+
def test_dataset_getitem_int_np_equivalence():
4609+
ds = Dataset.from_dict({"a": [0, 1, 2, 3]})
4610+
4611+
assert ds[1] == ds[np.int64(1)]
4612+
4613+
46084614
def test_dataset_getitem_raises():
46094615
ds = Dataset.from_dict({"a": [0, 1, 2, 3]})
46104616
with pytest.raises(TypeError):
46114617
ds[False]
46124618
with pytest.raises(TypeError):
46134619
ds._getitem(True)
4620+
with pytest.raises(TypeError):
4621+
ds[np.bool_(True)]
4622+
with pytest.raises(TypeError):
4623+
ds[1.0]
46144624

46154625

46164626
def test_categorical_dataset(tmpdir):

0 commit comments

Comments (0)