Skip to content

Commit 250be17

Browse files
Merge branch 'main' into exp/adaptive-optimization
2 parents 71b14c2 + a301c22 commit 250be17

File tree

2 files changed

+88
-0
lines changed

2 files changed

+88
-0
lines changed

codeflash/verification/comparator.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,19 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool: # noqa: ANN001
227227
if isinstance(orig, np.dtype):
228228
return orig == new
229229

230+
# Handle numpy random generators
231+
if isinstance(orig, np.random.Generator):
232+
# Compare the underlying BitGenerator state
233+
orig_state = orig.bit_generator.state
234+
new_state = new.bit_generator.state
235+
return comparator(orig_state, new_state, superset_obj)
236+
237+
if isinstance(orig, np.random.RandomState):
238+
# Compare the internal state
239+
orig_state = orig.get_state(legacy=False)
240+
new_state = new.get_state(legacy=False)
241+
return comparator(orig_state, new_state, superset_obj)
242+
230243
if HAS_SCIPY and isinstance(orig, scipy.sparse.spmatrix):
231244
if orig.dtype != new.dtype:
232245
return False

tests/test_comparator.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -360,6 +360,81 @@ def test_numpy():
360360
assert not comparator(a_void, c_void)
361361

362362

363+
def test_numpy_random_generator():
364+
try:
365+
import numpy as np
366+
except ImportError:
367+
pytest.skip()
368+
369+
# Test numpy.random.Generator (modern API)
370+
# Same seed should produce equal generators
371+
rng1 = np.random.default_rng(seed=42)
372+
rng2 = np.random.default_rng(seed=42)
373+
assert comparator(rng1, rng2)
374+
375+
# Different seeds should produce non-equal generators
376+
rng3 = np.random.default_rng(seed=123)
377+
assert not comparator(rng1, rng3)
378+
379+
# After generating numbers, state changes
380+
rng4 = np.random.default_rng(seed=42)
381+
rng5 = np.random.default_rng(seed=42)
382+
rng4.random() # Advance state
383+
assert not comparator(rng4, rng5)
384+
385+
# Both advanced by same amount should be equal
386+
rng5.random()
387+
assert comparator(rng4, rng5)
388+
389+
# Test with different bit generators
390+
from numpy.random import PCG64, MT19937
391+
rng_pcg1 = np.random.Generator(PCG64(seed=42))
392+
rng_pcg2 = np.random.Generator(PCG64(seed=42))
393+
assert comparator(rng_pcg1, rng_pcg2)
394+
395+
rng_mt1 = np.random.Generator(MT19937(seed=42))
396+
rng_mt2 = np.random.Generator(MT19937(seed=42))
397+
assert comparator(rng_mt1, rng_mt2)
398+
399+
# Different bit generator types should not be equal
400+
assert not comparator(rng_pcg1, rng_mt1)
401+
402+
403+
def test_numpy_random_state():
404+
try:
405+
import numpy as np
406+
except ImportError:
407+
pytest.skip()
408+
409+
# Test numpy.random.RandomState (legacy API)
410+
# Same seed should produce equal states
411+
rs1 = np.random.RandomState(seed=42)
412+
rs2 = np.random.RandomState(seed=42)
413+
assert comparator(rs1, rs2)
414+
415+
# Different seeds should produce non-equal states
416+
rs3 = np.random.RandomState(seed=123)
417+
assert not comparator(rs1, rs3)
418+
419+
# After generating numbers, state changes
420+
rs4 = np.random.RandomState(seed=42)
421+
rs5 = np.random.RandomState(seed=42)
422+
rs4.random() # Advance state
423+
assert not comparator(rs4, rs5)
424+
425+
# Both advanced by same amount should be equal
426+
rs5.random()
427+
assert comparator(rs4, rs5)
428+
429+
# Test state restoration
430+
rs6 = np.random.RandomState(seed=42)
431+
state = rs6.get_state()
432+
rs6.random() # Advance state
433+
rs7 = np.random.RandomState(seed=42)
434+
rs7.set_state(state)
435+
# rs6 advanced, rs7 restored to original state
436+
assert not comparator(rs6, rs7)
437+
363438

364439
def test_scipy():
365440
try:

0 commit comments

Comments
 (0)