1- import sys , importlib
2- from packaging . version import parse as parse_version
1+ import importlib
2+ import sys
33from platform import python_version
4+
45import pytest
5- from numpy .random import default_rng , Generator
6+ from numpy .random import Generator , default_rng
7+ from packaging .version import parse as parse_version
68
79from ragas .run_config import RunConfig
810
911if parse_version (python_version ()) < parse_version ("3.10" ):
10- from typing import NewType , Callable
11- RandomComparison = NewType ("RandomComparison" , Callable [[Generator , Generator ], bool ])
12+ from typing import Callable , NewType
13+
14+ RandomComparison = NewType (
15+ "RandomComparison" , Callable [[Generator , Generator ], bool ]
16+ )
1217elif parse_version (python_version ()) >= parse_version ("3.10" ):
13- from typing import TypeAlias , Callable
18+ from typing import Callable , TypeAlias
19+
1420 RandomComparison : TypeAlias = Callable [[Generator , Generator ], bool ]
1521
22+
1623@pytest .fixture (scope = "function" )
1724def compare_rng () -> Callable [[Generator , Generator ], bool ]:
18- """Pytest fixture wrapper to check :py:cls:`numpy.random.Generator` object equivalence.
25+ """Pytest fixture wrapper to check :py:cls:`numpy.random.Generator` object equivalence."""
1926
20- """
21- def _compare_rng (rng_0 :Generator , rng_1 :Generator ) -> bool :
27+ def _compare_rng (rng_0 : Generator , rng_1 : Generator ) -> bool :
2228 """Compare two :py:cls:`numpy.random.Generator`object.
23-
29+
2430 Args:
2531 rng_0 (numpy.random.Generator) : The first generator to compare with.
2632 rng_1 (numpy.random.Generator) : The second generator to compare with.
2733
2834 Returns:
2935 bool: Whether the two generators are at the same state.
30-
36+
3137 """
3238 return rng_0 .random () == rng_1 .random ()
33-
39+
3440 return _compare_rng
3541
3642
@@ -39,9 +45,11 @@ def _compare_rng(rng_0:Generator, rng_1:Generator) -> bool:
3945 (
4046 [42 , True ],
4147 [None , False ],
42- )
48+ ),
4349)
44- def test_random_num_generator (seed , compare_rng :RandomComparison , expected_equivalence ):
50+ def test_random_num_generator (
51+ seed , compare_rng : RandomComparison , expected_equivalence
52+ ):
4553 """Check :py:mod:`numpy.random` functionality and seed behaviour control."""
4654 rc = RunConfig (seed = seed )
4755
@@ -53,7 +61,7 @@ def test_random_num_generator(seed, compare_rng:RandomComparison, expected_equiv
5361 assert compare_rng (rc .rng , rng ) == expected_equivalence
5462
5563 # Check generation consistency
56- importlib .reload (sys .modules [' numpy.random' ])
64+ importlib .reload (sys .modules [" numpy.random" ])
5765 new_rc = RunConfig (seed = seed )
5866 new_rng = default_rng (seed = seed )
5967
@@ -63,22 +71,14 @@ def test_random_num_generator(seed, compare_rng:RandomComparison, expected_equiv
6371
6472 # Check equivalence
6573 if expected_equivalence :
66- assert all (
67- list (
68- map (
69- compare_rng ,
70- [rc .rng , new_rc .rng ],
71- [new_rng , rng ]
72- )
73- )
74- )
74+ assert all (list (map (compare_rng , [rc .rng , new_rc .rng ], [new_rng , rng ])))
7575 else :
7676 assert all (
77- list (
78- map (
79- lambda x , y :not compare_rng (x , y ),
80- [rc .rng , new_rc .rng ],
81- [new_rng , rng ]
82- )
77+ list (
78+ map (
79+ lambda x , y : not compare_rng (x , y ),
80+ [rc .rng , new_rc .rng ],
81+ [new_rng , rng ],
8382 )
8483 )
84+ )
0 commit comments