Skip to content

Commit 8e3e712

Browse files
committed
feat: use NumPy's RNG instead of igraph's own
1 parent 65d7528 commit 8e3e712

File tree

7 files changed

+208
-0
lines changed

7 files changed

+208
-0
lines changed

src/codegen/internal_lib.py.in

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ from .types import (
4040
igraph_maxflow_stats_t,
4141
igraph_isocompat_t,
4242
igraph_plfit_result_t,
43+
igraph_rng_t,
44+
igraph_rng_type_t,
4345
)
4446

4547

@@ -356,6 +358,23 @@ igraph_es_destroy = _lib.igraph_es_destroy
356358
igraph_es_destroy.restype = None
357359
igraph_es_destroy.argtypes = [c_void_p]
358360

361+
# Random number generators
362+
363+
igraph_rng_init = _lib.igraph_rng_init
364+
igraph_rng_init.restype = handle_igraph_error_t
365+
igraph_rng_init.argtypes = [POINTER(igraph_rng_t), POINTER(igraph_rng_type_t)]
366+
367+
igraph_rng_destroy = _lib.igraph_rng_destroy
368+
igraph_rng_destroy.restype = None
369+
igraph_rng_destroy.argtypes = [c_void_p]
370+
371+
igraph_rng_default = _lib.igraph_rng_default
372+
igraph_rng_default.restype = POINTER(igraph_rng_t)
373+
374+
igraph_rng_set_default = _lib.igraph_rng_set_default
375+
igraph_rng_set_default.restype = POINTER(igraph_rng_t)
376+
igraph_rng_set_default.argtypes = [POINTER(igraph_rng_t)]
377+
359378
# Graph type
360379

361380
igraph_destroy = _lib.igraph_destroy

src/igraph_ctypes/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
from .version import __version__
22

33
from .graph import Graph
4+
from ._internal.setup import setup_igraph_library
45

56
__all__ = ("Graph", "__version__")
7+
8+
setup_igraph_library()

src/igraph_ctypes/_internal/lib.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@
4040
igraph_maxflow_stats_t,
4141
igraph_isocompat_t,
4242
igraph_plfit_result_t,
43+
igraph_rng_t,
44+
igraph_rng_type_t,
4345
)
4446

4547

@@ -356,6 +358,23 @@ def _load_igraph_c_library():
356358
igraph_es_destroy.restype = None
357359
igraph_es_destroy.argtypes = [c_void_p]
358360

361+
# Random number generators
362+
363+
igraph_rng_init = _lib.igraph_rng_init
364+
igraph_rng_init.restype = handle_igraph_error_t
365+
igraph_rng_init.argtypes = [POINTER(igraph_rng_t), POINTER(igraph_rng_type_t)]
366+
367+
igraph_rng_destroy = _lib.igraph_rng_destroy
368+
igraph_rng_destroy.restype = None
369+
igraph_rng_destroy.argtypes = [c_void_p]
370+
371+
igraph_rng_default = _lib.igraph_rng_default
372+
igraph_rng_default.restype = POINTER(igraph_rng_t)
373+
374+
igraph_rng_set_default = _lib.igraph_rng_set_default
375+
igraph_rng_set_default.restype = POINTER(igraph_rng_t)
376+
igraph_rng_set_default.argtypes = [POINTER(igraph_rng_t)]
377+
359378
# Graph type
360379

361380
igraph_destroy = _lib.igraph_destroy

src/igraph_ctypes/_internal/rng.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from ctypes import pointer
2+
from functools import partial
3+
from numpy.random import Generator, PCG64
4+
from typing import Callable, Optional
5+
6+
from .lib import igraph_rng_set_default
7+
from .types import igraph_rng_type_t, np_type_of_igraph_uint_t
8+
from .wrappers import _RNG
9+
10+
__all__ = ("NumPyRNG",)
11+
12+
13+
class NumPyRNG:
14+
"""Implementation of an igraph RNG that wraps a NumPy RNG."""
15+
16+
_generator: Generator
17+
"""The wrapped NumPy random number generator."""
18+
19+
_rng: _RNG
20+
_rng_type: igraph_rng_type_t
21+
22+
def __init__(self, generator: Generator):
23+
# TODO(ntamas): currently we assume that Generator.bit_generator is
24+
# PCG64
25+
assert isinstance(generator.bit_generator, PCG64)
26+
27+
self._generator = generator
28+
self._rng_type = igraph_rng_type_t(
29+
name=b"NumPy RNG",
30+
bits=64,
31+
init=igraph_rng_type_t.TYPES["init"](self._rng_init),
32+
destroy=igraph_rng_type_t.TYPES["destroy"](self._rng_destroy),
33+
seed=igraph_rng_type_t.TYPES["seed"](self._rng_seed),
34+
get=igraph_rng_type_t.TYPES["get"](self._rng_get),
35+
)
36+
self._rng = _RNG.create(pointer(self._rng_type))
37+
self._rng.unwrap().is_seeded = True
38+
39+
def _rng_init(self, _state):
40+
_state[0] = None
41+
return 0 # IGRAPH_SUCCESS
42+
43+
def _rng_destroy(self, rng):
44+
pass
45+
46+
def _rng_seed(self, _state, value):
47+
# Ignore, we assume that NumPy RNGs are seeded externally
48+
return 0
49+
50+
def _rng_get(self, _state):
51+
"""
52+
return self._generator.bit_generator.ctypes.next_uint64(
53+
self._generator.bit_generator.ctypes.state
54+
)
55+
"""
56+
return self._generator.integers(
57+
0, 0xFFFFFFFFFFFFFFFF, dtype=np_type_of_igraph_uint_t, endpoint=True
58+
)
59+
60+
def attach(self) -> Callable[[], None]:
61+
"""Attaches this RNG instance as igraph's default RNG.
62+
63+
Returns:
64+
a callable that can be called with no arguments to restore the
65+
RNG that was in effect before this RNG was attached.
66+
"""
67+
global _igraph_default_rng
68+
69+
old = igraph_rng_set_default(self._rng)
70+
_igraph_default_rng = self
71+
72+
return partial(igraph_rng_set_default, old)
73+
74+
75+
_igraph_default_rng: Optional[NumPyRNG] = None
76+
"""We need to keep a reference to NumPyRNG to keep the underlying low-level
77+
C objects alive, so we use an internal object in this module for that.
78+
"""
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from .rng import NumPyRNG
2+
3+
__all__ = ("setup_igraph_library",)
4+
5+
6+
def setup_igraph_library() -> None:
7+
"""Initializes the random number generator of the igraph library.
8+
9+
This function is called when the ``igraph_ctypes`` module is imported by the user.
10+
"""
11+
from numpy.random import default_rng
12+
13+
NumPyRNG(default_rng()).attach()

src/igraph_ctypes/_internal/types.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
c_int64,
1111
c_long,
1212
c_ulong,
13+
c_uint8,
14+
c_uint64,
1315
c_void_p,
16+
py_object,
1417
POINTER,
1518
Structure,
1619
Union as CUnion,
@@ -38,11 +41,15 @@ def vector_fields(base_type):
3841
c_int64 # TODO(ntamas): this depends on whether igraph is 32-bit or 64-bit
3942
)
4043
igraph_real_t = c_double
44+
igraph_uint_t = (
45+
c_uint64 # TODO(ntamas): this depends on whether igraph is 32-bit or 64-bit
46+
)
4147

4248
# TODO(ntamas): these depend on whether igraph is 32-bit or 64-bit
4349
np_type_of_igraph_bool_t = np.bool_
4450
np_type_of_igraph_integer_t = np.int64
4551
np_type_of_igraph_real_t = np.float64
52+
np_type_of_igraph_uint_t = np.uint64
4653

4754

4855
class FILE(Structure):
@@ -347,6 +354,60 @@ class igraph_plfit_result_t(Structure):
347354
]
348355

349356

357+
igraph_rng_state_t = py_object
358+
359+
360+
class igraph_rng_type_t(Structure):
361+
"""ctypes representation of an ``igraph_rng_type_t`` object"""
362+
363+
TYPES = {
364+
"init": CFUNCTYPE(igraph_error_t, POINTER(igraph_rng_state_t)),
365+
"destroy": CFUNCTYPE(None, igraph_rng_state_t),
366+
"seed": CFUNCTYPE(igraph_error_t, igraph_rng_state_t, igraph_uint_t),
367+
"get": CFUNCTYPE(igraph_uint_t, igraph_rng_state_t),
368+
}
369+
370+
_fields_ = [
371+
("name", c_char_p),
372+
("bits", c_uint8),
373+
("init", TYPES["init"]),
374+
("destroy", TYPES["destroy"]),
375+
("seed", TYPES["seed"]),
376+
("get", TYPES["get"]),
377+
(
378+
"get_int",
379+
CFUNCTYPE(
380+
igraph_integer_t, igraph_rng_state_t, igraph_integer_t, igraph_integer_t
381+
),
382+
),
383+
("get_real", CFUNCTYPE(igraph_real_t, igraph_rng_state_t)),
384+
("get_norm", CFUNCTYPE(igraph_real_t, igraph_rng_state_t)),
385+
("get_geom", CFUNCTYPE(igraph_real_t, igraph_rng_state_t, igraph_real_t)),
386+
(
387+
"get_binom",
388+
CFUNCTYPE(
389+
igraph_real_t, igraph_rng_state_t, igraph_integer_t, igraph_real_t
390+
),
391+
),
392+
("get_exp", CFUNCTYPE(igraph_real_t, igraph_rng_state_t, igraph_real_t)),
393+
(
394+
"get_gamma",
395+
CFUNCTYPE(igraph_real_t, igraph_rng_state_t, igraph_real_t, igraph_real_t),
396+
),
397+
("get_pois", CFUNCTYPE(igraph_real_t, igraph_rng_state_t, igraph_real_t)),
398+
]
399+
400+
401+
class igraph_rng_t(Structure):
402+
"""ctypes representation of an ``igraph_rng_t`` object"""
403+
404+
_fields_ = [
405+
("type", POINTER(igraph_rng_type_t)),
406+
("state", igraph_rng_state_t),
407+
("is_seeded", igraph_bool_t),
408+
]
409+
410+
350411
igraph_isocompat_t = CFUNCTYPE(
351412
igraph_bool_t,
352413
POINTER(igraph_t),

src/igraph_ctypes/_internal/wrappers.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
igraph_matrix_init,
99
igraph_matrix_int_destroy,
1010
igraph_matrix_int_init,
11+
igraph_rng_init,
12+
igraph_rng_destroy,
1113
igraph_vector_destroy,
1214
igraph_vector_init,
1315
igraph_vector_bool_destroy,
@@ -25,6 +27,7 @@
2527
igraph_es_t,
2628
igraph_matrix_t,
2729
igraph_matrix_int_t,
30+
igraph_rng_t,
2831
igraph_vector_t,
2932
igraph_vector_bool_t,
3033
igraph_vector_int_t,
@@ -38,6 +41,7 @@
3841
"_Graph",
3942
"_Matrix",
4043
"_MatrixInt",
44+
"_RNG",
4145
"_Vector",
4246
"_VectorBool",
4347
"_VectorInt",
@@ -236,3 +240,14 @@ class _Graph(create_boxed("_Graph", igraph_t, destructor=igraph_destroy)):
236240
igraph_es_t,
237241
destructor=igraph_es_destroy,
238242
)
243+
244+
245+
class _RNG(
246+
create_boxed(
247+
"_RNG",
248+
igraph_rng_t,
249+
constructor=igraph_rng_init,
250+
destructor=igraph_rng_destroy,
251+
)
252+
):
253+
pass

0 commit comments

Comments
 (0)