Skip to content

Commit e4d9f8b

Browse files
authored
Fix is list on Numpy's masked arrays (#25)
1 parent 2a1d90c commit e4d9f8b

File tree

2 files changed

+23
-5
lines changed

2 files changed

+23
-5
lines changed

src/biocutils/is_list_of_type.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
from typing import Callable, Union
22

3+
import numpy as np
4+
import numpy.ma as ma
5+
36
__author__ = "jkanche"
47
__copyright__ = "jkanche"
58
__license__ = "MIT"
69

710

8-
def is_list_of_type(
9-
x: Union[list, tuple], target_type: Callable, ignore_none: bool = False
10-
) -> bool:
11+
def is_list_of_type(x: Union[list, tuple], target_type: Callable, ignore_none: bool = False) -> bool:
1112
"""Checks if ``x`` is a list, and whether all elements of the list are of the same type.
1213
1314
Args:
@@ -19,10 +20,16 @@ def is_list_of_type(
1920
2021
Returns:
2122
True if ``x`` is a list or tuple and all elements are of the target
22-
type (or None, if ``ignore_none = True``). Otherwise, false.
23+
type (or None, if ``ignore_none = True``). Otherwise, False.
2324
"""
24-
if not isinstance(x, (list, tuple)):
25+
if not isinstance(x, (list, tuple, np.ndarray, ma.MaskedArray)):
2526
return False
27+
28+
if isinstance(x, ma.MaskedArray):
29+
if not ignore_none:
30+
return all(x.mask) and all(isinstance(item, target_type) for item in x.data)
31+
else:
32+
return all(isinstance(item, target_type) for item in x.data[x.mask])
2633

2734
if not ignore_none:
2835
return all(isinstance(item, target_type) for item in x)

tests/test_list_type_checks.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import numpy as np
2+
import numpy.ma as ma
23
import pytest
34
from biocutils import is_list_of_type
45

@@ -29,3 +30,13 @@ def test_numpy_elems():
2930
x = [np.random.rand(3), np.random.rand(3, 2)]
3031

3132
assert is_list_of_type(x, np.ndarray)
33+
34+
def test_numpy_arrays():
35+
x = np.random.rand(3)
36+
37+
assert is_list_of_type(x, float)
38+
39+
x = ma.array([1, None, 3], mask = [0, 1, 0], dtype=np.float32)
40+
41+
assert is_list_of_type(x, np.float32, ignore_none=True)
42+
assert is_list_of_type(x, np.float32) is False

0 commit comments

Comments
 (0)