|
1 | | -import numpy as np |
| 1 | +import sys, importlib |
| 2 | +from packaging.version import parse as parse_version |
| 3 | +from platform import python_version |
2 | 4 | import pytest |
| 5 | +from numpy.random import default_rng, Generator |
3 | 6 |
|
4 | 7 | from ragas.run_config import RunConfig |
5 | 8 |
|
| 9 | +if parse_version(python_version()) < parse_version("3.10"): |
| 10 | + from typing import NewType, Callable |
| 11 | + RandomComparison = NewType("RandomComparison", Callable[[Generator, Generator], bool]) |
| 12 | +elif parse_version(python_version()) >= parse_version("3.10"): |
| 13 | + from typing import TypeAlias, Callable |
| 14 | + RandomComparison: TypeAlias = Callable[[Generator, Generator], bool] |
6 | 15 |
|
7 | | -def test_random_num_generator(): |
8 | | - rc = RunConfig(seed=32) |
9 | | - assert isinstance(rc.rng, np.random.Generator) |
10 | | - assert rc.rng.random() == pytest.approx(0.160, rel=1e2) |
| 16 | +@pytest.fixture(scope="function") |
| 17 | +def compare_rng() -> Callable[[Generator, Generator], bool]: |
| 18 | + """Pytest fixture wrapper to check :py:cls:`numpy.random.Generator` object equivalence. |
| 19 | +
|
| 20 | + """ |
| 21 | + def _compare_rng(rng_0:Generator, rng_1:Generator) -> bool: |
| 22 | + """Compare two :py:cls:`numpy.random.Generator`object. |
| 23 | + |
| 24 | + Args: |
| 25 | + rng_0 (numpy.random.Generator) : The first generator to compare with. |
| 26 | + rng_1 (numpy.random.Generator) : The second generator to compare with. |
| 27 | +
|
| 28 | + Returns: |
| 29 | + bool: Whether the two generators are at the same state. |
| 30 | + |
| 31 | + """ |
| 32 | + return rng_0.random() == rng_1.random() |
| 33 | + |
| 34 | + return _compare_rng |
| 35 | + |
| 36 | + |
| 37 | +@pytest.mark.parametrize( |
| 38 | + "seed, expected_equivalence", |
| 39 | + ( |
| 40 | + [42, True], |
| 41 | + [None, False], |
| 42 | + ) |
| 43 | +) |
| 44 | +def test_random_num_generator(seed, compare_rng:RandomComparison, expected_equivalence): |
| 45 | + """Check :py:mod:`numpy.random` functionality and seed behaviour control.""" |
| 46 | + rc = RunConfig(seed=seed) |
| 47 | + |
| 48 | + # Check type |
| 49 | + assert isinstance(rc.rng, Generator) |
| 50 | + |
| 51 | + # Check generated value |
| 52 | + rng = default_rng(seed=seed) |
| 53 | + assert compare_rng(rc.rng, rng) == expected_equivalence |
| 54 | + |
| 55 | + # Check generation consistency |
| 56 | + importlib.reload(sys.modules['numpy.random']) |
| 57 | + new_rc = RunConfig(seed=seed) |
| 58 | + new_rng = default_rng(seed=seed) |
| 59 | + |
| 60 | + # Put generator into the same state |
| 61 | + new_rc.rng.random() |
| 62 | + new_rng.random() |
| 63 | + |
| 64 | + # Check equivalence |
| 65 | + if expected_equivalence: |
| 66 | + assert all( |
| 67 | + list( |
| 68 | + map( |
| 69 | + compare_rng, |
| 70 | + [rc.rng, new_rc.rng], |
| 71 | + [new_rng, rng] |
| 72 | + ) |
| 73 | + ) |
| 74 | + ) |
| 75 | + else: |
| 76 | + assert all( |
| 77 | + list( |
| 78 | + map( |
| 79 | + lambda x, y:not compare_rng(x, y), |
| 80 | + [rc.rng, new_rc.rng], |
| 81 | + [new_rng, rng] |
| 82 | + ) |
| 83 | + ) |
| 84 | + ) |
0 commit comments