Skip to content

Commit 1c62466

Browse files
Gwenn-LRGwennLeRochleroch
authored
Add tests to ragas.run_config.RunConfig rng attribute. (#1169)
# Description Previous exchanges about the `numpy.random.Generator` (#1140, #1142 and #1152) led to the need to implement some tests. Previous ones had been setup but a more precise definition of the context box might strengthen the confidence in the overall generation. # Fonctionnality Check the ability to generate a specific `numpy.random.Generator`, its fixed behaviour between generations and its consistency accross module reloads but also the actual random generation when the `seed` is set to `None`. # Solution proposed Pytest with 2 parametrized test: 1 with the `seed` fixed, the other one with `seed` set to `None`. Each test will check whether the generation is fixed, reproductible and consistent according to the seed. --------- Co-authored-by: DirtyFock <[email protected]> Co-authored-by: leroch <[email protected]>
1 parent 1bcbf20 commit 1c62466

File tree

1 file changed

+79
-5
lines changed

1 file changed

+79
-5
lines changed

tests/unit/test_run_config.py

Lines changed: 79 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,84 @@
1-
import numpy as np
1+
import sys, importlib
2+
from packaging.version import parse as parse_version
3+
from platform import python_version
24
import pytest
5+
from numpy.random import default_rng, Generator
36

47
from ragas.run_config import RunConfig
58

9+
if parse_version(python_version()) < parse_version("3.10"):
10+
from typing import NewType, Callable
11+
RandomComparison = NewType("RandomComparison", Callable[[Generator, Generator], bool])
12+
elif parse_version(python_version()) >= parse_version("3.10"):
13+
from typing import TypeAlias, Callable
14+
RandomComparison: TypeAlias = Callable[[Generator, Generator], bool]
615

7-
def test_random_num_generator():
8-
rc = RunConfig(seed=32)
9-
assert isinstance(rc.rng, np.random.Generator)
10-
assert rc.rng.random() == pytest.approx(0.160, rel=1e2)
16+
@pytest.fixture(scope="function")
17+
def compare_rng() -> Callable[[Generator, Generator], bool]:
18+
"""Pytest fixture wrapper to check :py:cls:`numpy.random.Generator` object equivalence.
19+
20+
"""
21+
def _compare_rng(rng_0:Generator, rng_1:Generator) -> bool:
22+
"""Compare two :py:cls:`numpy.random.Generator`object.
23+
24+
Args:
25+
rng_0 (numpy.random.Generator) : The first generator to compare with.
26+
rng_1 (numpy.random.Generator) : The second generator to compare with.
27+
28+
Returns:
29+
bool: Whether the two generators are at the same state.
30+
31+
"""
32+
return rng_0.random() == rng_1.random()
33+
34+
return _compare_rng
35+
36+
37+
@pytest.mark.parametrize(
38+
"seed, expected_equivalence",
39+
(
40+
[42, True],
41+
[None, False],
42+
)
43+
)
44+
def test_random_num_generator(seed, compare_rng:RandomComparison, expected_equivalence):
45+
"""Check :py:mod:`numpy.random` functionality and seed behaviour control."""
46+
rc = RunConfig(seed=seed)
47+
48+
# Check type
49+
assert isinstance(rc.rng, Generator)
50+
51+
# Check generated value
52+
rng = default_rng(seed=seed)
53+
assert compare_rng(rc.rng, rng) == expected_equivalence
54+
55+
# Check generation consistency
56+
importlib.reload(sys.modules['numpy.random'])
57+
new_rc = RunConfig(seed=seed)
58+
new_rng = default_rng(seed=seed)
59+
60+
# Put generator into the same state
61+
new_rc.rng.random()
62+
new_rng.random()
63+
64+
# Check equivalence
65+
if expected_equivalence:
66+
assert all(
67+
list(
68+
map(
69+
compare_rng,
70+
[rc.rng, new_rc.rng],
71+
[new_rng, rng]
72+
)
73+
)
74+
)
75+
else:
76+
assert all(
77+
list(
78+
map(
79+
lambda x, y:not compare_rng(x, y),
80+
[rc.rng, new_rc.rng],
81+
[new_rng, rng]
82+
)
83+
)
84+
)

0 commit comments

Comments
 (0)