Skip to content

Commit e6825a9

Browse files
committed
tests for gibbs.ParallelGibbs
1 parent 19603db commit e6825a9

File tree

1 file changed

+72
-9
lines changed

1 file changed

+72
-9
lines changed

basicrta/tests/test_gibbs.py

Lines changed: 72 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,17 +63,23 @@ def mock_contact_file(tmp_path, synthetic_timeseries):
6363
"""
6464
from basicrta.tests.utils import make_Universe
6565

66-
# Create simple test Universe for ag1 and ag2 with resnames and resids attributes
67-
ag1_universe = make_Universe(extras=('resnames', 'resids'), size=(50, 10, 1)) # 50 atoms, 10 residues
68-
ag2_universe = make_Universe(extras=('resnames', 'resids'), size=(100, 20, 1)) # 100 atoms, 20 residues
66+
# Create simple test Universe for ag1 and ag2 with resnames and resids topology attributes
67+
residue_names = ['TRP', 'VAL', 'ALA', 'GLY', 'PHE', 'LEU', 'SER', 'THR', 'ASP', 'GLU']
68+
target_resids = [313, 314, 315, 316, 317, 318, 319, 320, 321, 322]
6969

70-
# Create AtomGroups
71-
ag1 = ag1_universe.atoms[:50] # All atoms for ag1
72-
ag2 = ag2_universe.atoms[:100] # All atoms for ag2
70+
# Create universe with topology attributes and values in one call
71+
ag1_universe = make_Universe(
72+
extras={
73+
'resnames': residue_names[:10], # 10 residues
74+
'resids': target_resids[:10]
75+
},
76+
size=(50, 10, 1) # 50 atoms, 10 residues
77+
)
78+
ag2_universe = make_Universe(size=(100, 20, 1)) # 100 atoms, 20 residues (no special attributes needed)
7379

74-
# Set residue names and IDs manually for testing
75-
ag1.residues.resnames[:] = ['TRP', 'VAL', 'ALA', 'GLY', 'PHE'] * 2 # Fill with test names
76-
ag1.residues.resids[:] = [313, 314, 315, 316, 317, 318, 319, 320, 321, 322] # Test residue IDs
80+
# Create AtomGroups
81+
ag1 = ag1_universe.atoms
82+
ag2 = ag2_universe.atoms
7783

7884
# Create contact data structure
7985
times = synthetic_timeseries['times']
@@ -309,6 +315,63 @@ def test_parallel_gibbs_initialization(self, mock_contact_file):
309315
assert parallel_gibbs.cutoff == 7.0, "Cutoff should be extracted from filename"
310316

311317

318+
def test_parallel_gibbs_run_method(self, tmp_path, mock_contact_file, synthetic_timeseries):
319+
"""Test the run() method for ParallelGibbs class with real multiprocessing."""
320+
321+
# Change to tmp_path for output
322+
with work_in(tmp_path):
323+
# Initialize ParallelGibbs with smaller parameters for faster testing
324+
parallel_gibbs = ParallelGibbs(
325+
contacts=mock_contact_file,
326+
nproc=2, # Use 2 processes for testing
327+
ncomp=2, # Use fewer components for speed
328+
niter=1000 # Smaller iteration count for testing
329+
)
330+
331+
# Test initialization
332+
assert parallel_gibbs.contacts == mock_contact_file
333+
assert parallel_gibbs.nproc == 2
334+
assert parallel_gibbs.ncomp == 2
335+
assert parallel_gibbs.niter == 1000
336+
assert parallel_gibbs.cutoff == 7.0
337+
338+
# Run ParallelGibbs on residue 313 (which exists in our mock contact file)
339+
parallel_gibbs.run(run_resids=[313])
340+
341+
# Verify that the expected output directory structure was created
342+
expected_residue_dir = tmp_path / f"basicrta-{parallel_gibbs.cutoff}" / "W313"
343+
assert expected_residue_dir.exists(), f"Residue directory should be created: {expected_residue_dir}"
344+
345+
# Verify that the Gibbs sampler output file was created
346+
expected_gibbs_file = expected_residue_dir / f"gibbs_{parallel_gibbs.niter}.pkl"
347+
assert expected_gibbs_file.exists(), f"Gibbs output file should be created: {expected_gibbs_file}"
348+
349+
# Load and verify the Gibbs sampler results
350+
from basicrta.gibbs import Gibbs
351+
gibbs_result = Gibbs.load(str(expected_gibbs_file))
352+
353+
# Verify the loaded Gibbs sampler has the expected properties
354+
assert gibbs_result.residue == 'W313', "Residue should match"
355+
assert gibbs_result.ncomp == 2, "Number of components should match"
356+
assert gibbs_result.niter == 1000, "Number of iterations should match"
357+
assert gibbs_result.cutoff == 7.0, "Cutoff should match"
358+
359+
# Verify that the Gibbs sampler ran successfully
360+
assert hasattr(gibbs_result, 'mcweights'), "Should have mcweights"
361+
assert hasattr(gibbs_result, 'mcrates'), "Should have mcrates"
362+
assert np.all(gibbs_result.mcweights >= 0), "mcweights should be non-negative"
363+
assert np.all(gibbs_result.mcrates >= 0), "mcrates should be non-negative"
364+
365+
# Check that we have the expected number of samples
366+
expected_samples = (gibbs_result.niter + 1) // gibbs_result.g
367+
assert gibbs_result.mcweights.shape[0] == expected_samples, "Should have correct number of weight samples"
368+
assert gibbs_result.mcrates.shape[0] == expected_samples, "Should have correct number of rate samples"
369+
370+
# Verify that the residence times used match our synthetic data
371+
assert len(gibbs_result.times) == len(synthetic_timeseries['times']), "Should use all synthetic residence times"
372+
assert np.allclose(gibbs_result.times, synthetic_timeseries['times']), "Residence times should match synthetic data"
373+
374+
312375
def test_gibbs_sampler_old_style_input(self, tmp_path, synthetic_timeseries):
313376
"""Test Gibbs sampler with old-style (non-combined) input data."""
314377

0 commit comments

Comments
 (0)