Skip to content

Commit 10c1fe4

Browse files
committed
feat: added get_numpy_attribute_type_from_iterable()
1 parent c33b557 commit 10c1fe4

File tree

3 files changed

+96
-1
lines changed

3 files changed

+96
-1
lines changed

src/igraph_ctypes/_internal/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def vector_fields(base_type):
4949
np_type_of_igraph_bool_t = np.bool_
5050
np_type_of_igraph_integer_t = np.int64
5151
np_type_of_igraph_real_t = np.float64
52+
np_type_of_igraph_string = np.string_
5253
np_type_of_igraph_uint_t = np.uint64
5354

5455

src/igraph_ctypes/_internal/utils.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,16 @@
1+
import numpy as np
2+
13
from ctypes import byref, cast, c_char_p, c_ubyte, POINTER, sizeof
24
from functools import wraps
3-
from typing import Callable, Union
5+
from numpy.typing import DTypeLike
6+
from typing import Any, Callable, Iterable, Union
47

58
from .errors import python_exception_to_igraph_error_t
9+
from .types import (
10+
np_type_of_igraph_bool_t,
11+
np_type_of_igraph_real_t,
12+
np_type_of_igraph_string,
13+
)
614

715
__all__ = ("bytes_to_str", "get_raw_memory_view", "nop", "protect", "protect_with")
816

@@ -21,6 +29,57 @@ def bytes_to_str(
2129
return value.decode(encoding, errors)
2230

2331

32+
def get_numpy_attribute_type_from_iterable( # noqa: C901
33+
it: Iterable[Any],
34+
) -> DTypeLike:
35+
"""Determines the appropriate NumPy datatype to store all the items found in
36+
the given iterable as an attribute.
37+
38+
This means that this function basically classifies the iterable into one of
39+
the following data types:
40+
41+
- np_type_of_igraph_bool_t for Boolean attributes
42+
- np_type_of_igraph_real_t for numeric attributes
43+
- np_type_of_igraph_string for string attributes
44+
- np.object_ for any other (mixed) attribute type
45+
46+
When the iterable is empty, a numeric attribute will be assumed.
47+
"""
48+
it = iter(it)
49+
try:
50+
item = next(it)
51+
except StopIteration:
52+
# Iterable empty
53+
return np_type_of_igraph_real_t
54+
55+
best_fit: DTypeLike
56+
if isinstance(item, bool):
57+
best_fit = np_type_of_igraph_bool_t
58+
elif isinstance(item, (int, float, np.number)):
59+
best_fit = np_type_of_igraph_real_t
60+
elif isinstance(item, str):
61+
best_fit = np_type_of_igraph_string
62+
else:
63+
return np.object_
64+
65+
for item in it:
66+
if isinstance(item, bool):
67+
if best_fit == np_type_of_igraph_string:
68+
return np.object_
69+
elif isinstance(item, (int, float, np.number)):
70+
if best_fit == np_type_of_igraph_string:
71+
return np.object_
72+
else:
73+
best_fit = np_type_of_igraph_real_t
74+
elif isinstance(item, str):
75+
if best_fit != np_type_of_igraph_string:
76+
return np.object_
77+
else:
78+
return np.object_
79+
80+
return best_fit
81+
82+
2483
def get_raw_memory_view(obj):
2584
"""Returns a view into the raw bytes of a ctypes object."""
2685
return cast(byref(obj), POINTER(c_ubyte * sizeof(obj))).contents

tests/test_utils.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from numpy import inf, nan, object_
2+
3+
from igraph_ctypes._internal.types import (
4+
np_type_of_igraph_bool_t,
5+
np_type_of_igraph_string,
6+
np_type_of_igraph_real_t,
7+
)
8+
from igraph_ctypes._internal.utils import get_numpy_attribute_type_from_iterable
9+
10+
from pytest import mark
11+
12+
13+
get = get_numpy_attribute_type_from_iterable
14+
15+
16+
@mark.parametrize(
17+
("input", "expected"),
18+
[
19+
((), np_type_of_igraph_real_t),
20+
((17, -2), np_type_of_igraph_real_t),
21+
((17.5, -3.5), np_type_of_igraph_real_t),
22+
((inf, -3.5, nan), np_type_of_igraph_real_t),
23+
((False, True), np_type_of_igraph_bool_t),
24+
((False, 123, True), np_type_of_igraph_real_t),
25+
((123, False, True), np_type_of_igraph_real_t),
26+
(("spam", "ham", "bacon"), np_type_of_igraph_string),
27+
(("spam", 123), object_),
28+
((123, "spam"), object_),
29+
((None,), object_),
30+
(("spam", False), object_),
31+
((123, None), object_),
32+
],
33+
)
34+
def test_empty(input, expected):
35+
assert get(input) == expected

0 commit comments

Comments
 (0)