Skip to content

Commit e26ac3c

Browse files
committed
fix: make sure that the NumPy RNG is called without going through Python
1 parent 145fb68 commit e26ac3c

File tree

4 files changed

+6
-55
lines changed

4 files changed

+6
-55
lines changed

src/igraph_ctypes/_internal/lib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
# fmt: off
88

9-
from ctypes import cdll, c_char_p, c_double, c_int, c_void_p, POINTER
9+
from ctypes import cdll, c_bool, c_char_p, c_double, c_int, c_void_p, POINTER
1010
from ctypes.util import find_library
1111

1212
from .errors import handle_igraph_error_t

src/igraph_ctypes/_internal/rng.py

Lines changed: 3 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -25,32 +25,21 @@ class NumPyRNG:
2525
_rng_type: igraph_rng_type_t
2626

2727
def __init__(self, generator: Generator):
28-
# TODO(ntamas): currently we assume that Generator.bit_generator is
29-
# PCG64
30-
assert isinstance(generator.bit_generator, PCG64)
31-
3228
self._generator = generator
3329
self._rng_type = igraph_rng_type_t(
3430
name=b"NumPy RNG",
3531
bits=64,
3632
init=igraph_rng_type_t.TYPES["init"](self._rng_init),
3733
destroy=igraph_rng_type_t.TYPES["destroy"](self._rng_destroy),
3834
seed=igraph_rng_type_t.TYPES["seed"](self._rng_seed),
39-
get=igraph_rng_type_t.TYPES["get"](self._rng_get),
40-
get_int=igraph_rng_type_t.TYPES["get_int"](self._rng_get_int),
41-
get_real=igraph_rng_type_t.TYPES["get_real"](self._rng_get_real),
42-
get_norm=igraph_rng_type_t.TYPES["get_norm"](self._rng_get_norm),
43-
get_geom=igraph_rng_type_t.TYPES["get_geom"](self._rng_get_geom),
44-
get_binom=igraph_rng_type_t.TYPES["get_binom"](self._rng_get_binom),
45-
get_exp=igraph_rng_type_t.TYPES["get_exp"](self._rng_get_exp),
46-
get_gamma=igraph_rng_type_t.TYPES["get_gamma"](self._rng_get_gamma),
47-
get_pois=igraph_rng_type_t.TYPES["get_pois"](self._rng_get_pois),
35+
get=self._generator.bit_generator.ctypes.next_uint64,
36+
get_real=self._generator.bit_generator.ctypes.next_double,
4837
)
4938
self._rng = _RNG.create(pointer(self._rng_type))
5039
self._rng.unwrap().is_seeded = True
5140

5241
def _rng_init(self, _state):
53-
_state[0] = None
42+
_state[0] = self._generator.bit_generator.ctypes.state_address
5443
return 0 # IGRAPH_SUCCESS
5544

5645
def _rng_destroy(self, rng):
@@ -60,44 +49,6 @@ def _rng_seed(self, _state, value):
6049
# Ignore, we assume that NumPy RNGs are seeded externally
6150
return 0
6251

63-
def _rng_get(self, _state):
64-
"""
65-
return self._generator.bit_generator.ctypes.next_uint64(
66-
self._generator.bit_generator.ctypes.state
67-
)
68-
"""
69-
return self._generator.integers(
70-
0, 0xFFFFFFFFFFFFFFFF, dtype=np_type_of_igraph_uint_t, endpoint=True
71-
)
72-
73-
def _rng_get_int(self, _state, lo, hi):
74-
return self._generator.integers(
75-
lo, hi, dtype=np_type_of_igraph_integer_t, endpoint=True
76-
)
77-
78-
def _rng_get_real(self, _state):
79-
return self._generator.random(dtype=np_type_of_igraph_real_t)
80-
81-
def _rng_get_norm(self, _state):
82-
return self._generator.normal()
83-
84-
def _rng_get_geom(self, _state, p):
85-
# NumPy uses 1-based return values, igraph assumes 0-based
86-
return self._generator.geometric(p) - 1
87-
88-
def _rng_get_binom(self, _state, n, p):
89-
return self._generator.binomial(n, p)
90-
91-
def _rng_get_exp(self, _state, rate):
92-
# NumPy uses the scale parameter, igraph supplies the rate parameter
93-
return self._generator.exponential(1 / rate)
94-
95-
def _rng_get_gamma(self, _state, shape, scale):
96-
return self._generator.gamma(shape, scale)
97-
98-
def _rng_get_pois(self, _state, rate):
99-
return self._generator.poisson(rate)
100-
10152
def attach(self) -> Callable[[], None]:
10253
"""Attaches this RNG instance as igraph's default RNG.
10354

src/igraph_ctypes/_internal/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,4 +144,4 @@ def setup_igraph_library() -> None:
144144
"""
145145
_setup_error_handlers()
146146
_setup_interruption_handler()
147-
# _setup_rng()
147+
_setup_rng()

src/igraph_ctypes/_internal/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,7 @@ class igraph_plfit_result_t(Structure):
354354
]
355355

356356

357-
igraph_rng_state_t = py_object
357+
igraph_rng_state_t = c_void_p
358358

359359

360360
class igraph_rng_type_t(Structure):

0 commit comments

Comments
 (0)