Skip to content

Commit cdf9209

Browse files
committed
feat: use NumPy RNG for get_int(), get_real() and get_norm() as well
1 parent 8e3e712 commit cdf9209

File tree

2 files changed

+43
-22
lines changed

2 files changed

+43
-22
lines changed

src/igraph_ctypes/_internal/rng.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,12 @@
44
from typing import Callable, Optional
55

66
from .lib import igraph_rng_set_default
7-
from .types import igraph_rng_type_t, np_type_of_igraph_uint_t
7+
from .types import (
8+
igraph_rng_type_t,
9+
np_type_of_igraph_integer_t,
10+
np_type_of_igraph_real_t,
11+
np_type_of_igraph_uint_t,
12+
)
813
from .wrappers import _RNG
914

1015
__all__ = ("NumPyRNG",)
@@ -32,6 +37,10 @@ def __init__(self, generator: Generator):
3237
destroy=igraph_rng_type_t.TYPES["destroy"](self._rng_destroy),
3338
seed=igraph_rng_type_t.TYPES["seed"](self._rng_seed),
3439
get=igraph_rng_type_t.TYPES["get"](self._rng_get),
40+
get_int=igraph_rng_type_t.TYPES["get_int"](self._rng_get_int),
41+
get_real=igraph_rng_type_t.TYPES["get_real"](self._rng_get_real),
42+
get_norm=igraph_rng_type_t.TYPES["get_norm"](self._rng_get_norm),
43+
# TODO(ntamas): get_geom, get_binom, get_exp, get_gamma, get_pois
3544
)
3645
self._rng = _RNG.create(pointer(self._rng_type))
3746
self._rng.unwrap().is_seeded = True
@@ -57,6 +66,17 @@ def _rng_get(self, _state):
5766
0, 0xFFFFFFFFFFFFFFFF, dtype=np_type_of_igraph_uint_t, endpoint=True
5867
)
5968

69+
def _rng_get_int(self, _state, lo, hi):
70+
return self._generator.integers(
71+
lo, hi, dtype=np_type_of_igraph_integer_t, endpoint=True
72+
)
73+
74+
def _rng_get_real(self, _state):
75+
return self._generator.random(dtype=np_type_of_igraph_real_t)
76+
77+
def _rng_get_norm(self, _state):
78+
return self._generator.normal()
79+
6080
def attach(self) -> Callable[[], None]:
6181
"""Attaches this RNG instance as igraph's default RNG.
6282

src/igraph_ctypes/_internal/types.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,20 @@ class igraph_rng_type_t(Structure):
365365
"destroy": CFUNCTYPE(None, igraph_rng_state_t),
366366
"seed": CFUNCTYPE(igraph_error_t, igraph_rng_state_t, igraph_uint_t),
367367
"get": CFUNCTYPE(igraph_uint_t, igraph_rng_state_t),
368+
"get_int": CFUNCTYPE(
369+
igraph_integer_t, igraph_rng_state_t, igraph_integer_t, igraph_integer_t
370+
),
371+
"get_real": CFUNCTYPE(igraph_real_t, igraph_rng_state_t),
372+
"get_norm": CFUNCTYPE(igraph_real_t, igraph_rng_state_t),
373+
"get_geom": CFUNCTYPE(igraph_real_t, igraph_rng_state_t, igraph_real_t),
374+
"get_binom": CFUNCTYPE(
375+
igraph_real_t, igraph_rng_state_t, igraph_integer_t, igraph_real_t
376+
),
377+
"get_exp": CFUNCTYPE(igraph_real_t, igraph_rng_state_t, igraph_real_t),
378+
"get_gamma": CFUNCTYPE(
379+
igraph_real_t, igraph_rng_state_t, igraph_real_t, igraph_real_t
380+
),
381+
"get_pois": CFUNCTYPE(igraph_real_t, igraph_rng_state_t, igraph_real_t),
368382
}
369383

370384
_fields_ = [
@@ -374,27 +388,14 @@ class igraph_rng_type_t(Structure):
374388
("destroy", TYPES["destroy"]),
375389
("seed", TYPES["seed"]),
376390
("get", TYPES["get"]),
377-
(
378-
"get_int",
379-
CFUNCTYPE(
380-
igraph_integer_t, igraph_rng_state_t, igraph_integer_t, igraph_integer_t
381-
),
382-
),
383-
("get_real", CFUNCTYPE(igraph_real_t, igraph_rng_state_t)),
384-
("get_norm", CFUNCTYPE(igraph_real_t, igraph_rng_state_t)),
385-
("get_geom", CFUNCTYPE(igraph_real_t, igraph_rng_state_t, igraph_real_t)),
386-
(
387-
"get_binom",
388-
CFUNCTYPE(
389-
igraph_real_t, igraph_rng_state_t, igraph_integer_t, igraph_real_t
390-
),
391-
),
392-
("get_exp", CFUNCTYPE(igraph_real_t, igraph_rng_state_t, igraph_real_t)),
393-
(
394-
"get_gamma",
395-
CFUNCTYPE(igraph_real_t, igraph_rng_state_t, igraph_real_t, igraph_real_t),
396-
),
397-
("get_pois", CFUNCTYPE(igraph_real_t, igraph_rng_state_t, igraph_real_t)),
391+
("get_int", TYPES["get_int"]),
392+
("get_real", TYPES["get_real"]),
393+
("get_norm", TYPES["get_norm"]),
394+
("get_geom", TYPES["get_geom"]),
395+
("get_binom", TYPES["get_binom"]),
396+
("get_exp", TYPES["get_exp"]),
397+
("get_gamma", TYPES["get_gamma"]),
398+
("get_pois", TYPES["get_pois"]),
398399
]
399400

400401

0 commit comments

Comments
 (0)