Skip to content

Commit cb2803a

Browse files
committed
additional tests for gibbs
- test Gibbs.run() - test ParallelGibbs.run()
1 parent fb29627 commit cb2803a

File tree

1 file changed

+305
-21
lines changed

1 file changed

+305
-21
lines changed

basicrta/tests/test_gibbs.py

Lines changed: 305 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,326 @@
11
"""
2-
Tests for the Gibbs sampler with old-style input.
2+
Tests for the Gibbs sampler with comprehensive coverage.
33
"""
44

55
import pytest
66
import numpy as np
77
import tempfile
88
import os
9-
from basicrta.gibbs import Gibbs
9+
import pickle
10+
from unittest.mock import Mock, patch
11+
from basicrta.gibbs import Gibbs, ParallelGibbs
12+
from basicrta.tests.utils import work_in
13+
import MDAnalysis as mda
14+
15+
16+
@pytest.fixture
17+
def synthetic_timeseries():
18+
"""
19+
Generate synthetic residence times from a bi-exponential distribution.
20+
21+
Returns
22+
-------
23+
dict
24+
Dictionary containing test timeseries and expected parameters
25+
"""
26+
rng = np.random.default_rng(seed=42)
27+
n_samples = 200
28+
29+
# Create a bimodal distribution with known parameters
30+
# Fast component: rate ~2.0 (1/0.5), weight ~0.5
31+
times_short = rng.exponential(0.5, n_samples // 2)
32+
# Slow component: rate ~0.2 (1/5.0), weight ~0.5
33+
times_long = rng.exponential(5.0, n_samples // 2)
34+
35+
test_times = np.concatenate([times_short, times_long])
36+
rng.shuffle(test_times) # Mix them up
37+
38+
return {
39+
'times': test_times,
40+
'expected_components': 2,
41+
'expected_rates_approx': [2.0, 0.2], # Approximate expected rates
42+
'expected_weights_approx': [0.5, 0.5], # Approximate expected weights
43+
'n_samples': n_samples
44+
}
45+
46+
47+
@pytest.fixture
48+
def mock_contact_file(tmp_path, synthetic_timeseries):
49+
"""
50+
Create a mock contact pickle file for testing ParallelGibbs.
51+
52+
Parameters
53+
----------
54+
tmp_path : Path
55+
Pytest temporary directory fixture
56+
synthetic_timeseries : dict
57+
Synthetic timeseries data from fixture
58+
59+
Returns
60+
-------
61+
str
62+
Path to the created contact file
63+
"""
64+
from basicrta.tests.utils import make_Universe
65+
66+
# Create simple test Universe for ag1 and ag2 with resnames and resids attributes
67+
ag1_universe = make_Universe(extras=('resnames', 'resids'), size=(50, 10, 1)) # 50 atoms, 10 residues
68+
ag2_universe = make_Universe(extras=('resnames', 'resids'), size=(100, 20, 1)) # 100 atoms, 20 residues
69+
70+
# Create AtomGroups
71+
ag1 = ag1_universe.atoms[:50] # All atoms for ag1
72+
ag2 = ag2_universe.atoms[:100] # All atoms for ag2
73+
74+
# Set residue names and IDs manually for testing
75+
ag1.residues.resnames[:] = ['TRP', 'VAL', 'ALA', 'GLY', 'PHE'] * 2 # Fill with test names
76+
ag1.residues.resids[:] = [313, 314, 315, 316, 317, 318, 319, 320, 321, 322] # Test residue IDs
77+
78+
# Create contact data structure
79+
times = synthetic_timeseries['times']
80+
n_contacts = len(times)
81+
82+
# Contact format: [protein_resid, lipid_resid, frame, residence_time, contact_number]
83+
contacts = np.zeros((n_contacts, 5))
84+
contacts[:, 0] = 313 # All contacts with residue 313 (TRP)
85+
contacts[:, 1] = np.arange(n_contacts) # Different lipid residues
86+
contacts[:, 2] = np.arange(n_contacts) # Sequential frames
87+
contacts[:, 3] = times # Residence times
88+
contacts[:, 4] = np.arange(n_contacts) # Contact numbers
89+
90+
# Create metadata
91+
metadata = {
92+
'ag1': ag1,
93+
'ag2': ag2,
94+
'ts': 0.1, # timestep
95+
'top': 'test.pdb',
96+
'traj': 'test.xtc'
97+
}
98+
99+
# Create numpy array with metadata
100+
contacts_dtype = np.dtype(contacts.dtype, metadata=metadata)
101+
contacts_array = contacts.astype(contacts_dtype)
102+
103+
# Save to pickle file
104+
contact_file = tmp_path / "contacts_7.0.pkl"
105+
with open(contact_file, 'wb') as f:
106+
pickle.dump(contacts_array, f)
107+
108+
return str(contact_file)
10109

11110

12111
class TestGibbsSampler:
13-
"""Tests for Gibbs sampler with traditional input."""
112+
"""Comprehensive tests for both Gibbs sampler classes."""
14113

15-
def test_gibbs_sampler_old_style_input(self, tmp_path):
114+
@pytest.mark.parametrize("init_kwargs", [
115+
{
116+
'times': None, # Will be set from fixture
117+
'residue': 'W313',
118+
'ncomp': 2,
119+
'niter': 1000,
120+
'burnin': 5,
121+
'cutoff': 7.0,
122+
'g': 100
123+
},
124+
])
125+
def test_gibbs_run_method(self, tmp_path, synthetic_timeseries, init_kwargs):
126+
"""Test the run() method for Gibbs class with synthetic data."""
127+
128+
# Set up the times from fixture
129+
if 'times' in init_kwargs:
130+
init_kwargs['times'] = synthetic_timeseries['times']
131+
132+
# Create output directory structure
133+
output_dir = tmp_path / f"basicrta-{init_kwargs['cutoff']}" / init_kwargs['residue']
134+
output_dir.mkdir(parents=True, exist_ok=True)
135+
136+
# Change to tmp_path to avoid creating output in the repo
137+
with work_in(tmp_path):
138+
# Initialize Gibbs sampler
139+
gibbs = Gibbs(**init_kwargs)
140+
141+
# Test that initialization worked correctly
142+
assert gibbs.times is not None, "Times should be set"
143+
assert gibbs.ncomp == 2, "Should have 2 components"
144+
assert gibbs.niter == 1000, "Should have 1000 iterations"
145+
146+
# Run the sampler
147+
gibbs.run()
148+
149+
# Verify that the sampler ran successfully
150+
assert hasattr(gibbs, 'mcweights'), "Gibbs sampler should have mcweights after running"
151+
assert hasattr(gibbs, 'mcrates'), "Gibbs sampler should have mcrates after running"
152+
assert gibbs.mcweights is not None, "mcweights should not be None"
153+
assert gibbs.mcrates is not None, "mcrates should not be None"
154+
155+
# Check that we have the expected number of samples
156+
expected_samples = (gibbs.niter + 1) // gibbs.g
157+
assert gibbs.mcweights.shape[0] == expected_samples, f"Expected {expected_samples} weight samples"
158+
assert gibbs.mcrates.shape[0] == expected_samples, f"Expected {expected_samples} rate samples"
159+
160+
# Check dimensions match number of components
161+
assert gibbs.mcweights.shape[1] == gibbs.ncomp, f"Should have {gibbs.ncomp} weight components"
162+
assert gibbs.mcrates.shape[1] == gibbs.ncomp, f"Should have {gibbs.ncomp} rate components"
163+
164+
# Verify weights are properly normalized (sum to ~1 for each sample)
165+
weight_sums = np.sum(gibbs.mcweights, axis=1)
166+
np.testing.assert_allclose(weight_sums, 1.0, rtol=1e-10,
167+
err_msg="Weights should sum to 1 for each sample")
168+
169+
# Verify rates are positive
170+
assert np.all(gibbs.mcrates > 0), "All rates should be positive"
171+
172+
# Check that survival function was computed
173+
assert hasattr(gibbs, 't'), "Should have time points for survival function"
174+
assert hasattr(gibbs, 's'), "Should have survival function values"
175+
assert len(gibbs.t) > 0, "Time points should not be empty"
176+
assert len(gibbs.s) > 0, "Survival function should not be empty"
177+
178+
# Check that indicator variables were stored
179+
assert hasattr(gibbs, 'indicator'), "Should have indicator variables"
180+
assert gibbs.indicator is not None, "Indicator should not be None"
181+
assert gibbs.indicator.shape[0] == expected_samples, "Indicator should match sample count"
182+
assert gibbs.indicator.shape[1] == len(gibbs.times), "Indicator should match data count"
183+
184+
185+
def test_parallel_gibbs_initialization_and_contact_loading(self, tmp_path, mock_contact_file, synthetic_timeseries):
186+
"""Test ParallelGibbs initialization and contact file loading."""
187+
188+
# Change to tmp_path for output
189+
with work_in(tmp_path):
190+
# Initialize ParallelGibbs
191+
parallel_gibbs = ParallelGibbs(
192+
contacts=mock_contact_file,
193+
nproc=2, # 2 cores are always available ...
194+
ncomp=4,
195+
niter=1000
196+
)
197+
198+
# Test initialization
199+
assert parallel_gibbs.contacts == mock_contact_file
200+
assert parallel_gibbs.nproc == 2
201+
assert parallel_gibbs.ncomp == 4
202+
assert parallel_gibbs.niter == 1000
203+
assert parallel_gibbs.cutoff == 7.0 # Extracted from filename
204+
205+
# Test that the contact file can be loaded and processed
206+
with open(mock_contact_file, 'rb') as f:
207+
contacts = pickle.load(f)
208+
209+
# Verify the contact file structure
210+
assert contacts.shape[1] == 5, "Contact data should have 5 columns"
211+
assert len(contacts) == len(synthetic_timeseries['times']), "Should have correct number of contacts"
212+
213+
# Test contact data processing logic (without running full Gibbs)
214+
metadata = contacts.dtype.metadata
215+
assert 'ag1' in metadata, "Metadata should contain ag1"
216+
assert 'ag2' in metadata, "Metadata should contain ag2"
217+
assert 'ts' in metadata, "Metadata should contain timestep"
218+
219+
protids = np.unique(contacts[:, 0])
220+
assert 313 in protids, "Should have protein residue 313"
221+
222+
# Test that residence times are extracted correctly
223+
residue_313_times = contacts[contacts[:, 0] == 313][:, 3]
224+
assert len(residue_313_times) > 0, "Should have residence times for residue 313"
225+
assert np.allclose(residue_313_times, synthetic_timeseries['times']), "Residence times should match synthetic data"
226+
227+
228+
def test_gibbs_initialization_parameters(self, synthetic_timeseries):
229+
"""Test that Gibbs class initializes with correct parameters."""
230+
times = synthetic_timeseries['times']
231+
232+
gibbs = Gibbs(
233+
times=times,
234+
residue='TEST123',
235+
loc=0,
236+
ncomp=3,
237+
niter=5000,
238+
cutoff=8.5,
239+
g=25,
240+
burnin=500,
241+
gskip=5
242+
)
243+
244+
# Test all initialization parameters
245+
assert np.array_equal(gibbs.times, times), "Times should be stored correctly"
246+
assert gibbs.residue == 'TEST123', "Residue should be stored correctly"
247+
assert gibbs.loc == 0, "Location should be stored correctly"
248+
assert gibbs.ncomp == 3, "Number of components should be stored correctly"
249+
assert gibbs.niter == 5000, "Number of iterations should be stored correctly"
250+
assert gibbs.cutoff == 8.5, "Cutoff should be stored correctly"
251+
assert gibbs.g == 25, "Gibbs skip should be stored correctly"
252+
assert gibbs.burnin == 500, "Burnin should be stored correctly"
253+
assert gibbs.gskip == 5, "Gibbs skip should be stored correctly"
254+
255+
# Test that timestep is computed correctly
256+
assert gibbs.ts is not None, "Timestep should be computed"
257+
assert gibbs.ts > 0, "Timestep should be positive"
258+
259+
260+
def test_gibbs_prepare_method(self, synthetic_timeseries):
261+
"""Test the _prepare() method of Gibbs class."""
262+
times = synthetic_timeseries['times']
263+
264+
gibbs = Gibbs(
265+
times=times,
266+
residue='W313',
267+
ncomp=2,
268+
niter=1000,
269+
g=100
270+
)
271+
272+
# Call _prepare method
273+
gibbs._prepare()
274+
275+
# Check that survival function was computed
276+
assert hasattr(gibbs, 't'), "Should have time points"
277+
assert hasattr(gibbs, 's'), "Should have survival function"
278+
assert len(gibbs.t) > 0, "Time points should not be empty"
279+
assert len(gibbs.s) > 0, "Survival function should not be empty"
280+
assert np.all(gibbs.s >= 0), "Survival function should be non-negative"
281+
assert np.all(gibbs.s <= 1), "Survival function should be <= 1"
282+
283+
# Check that arrays were initialized with correct shapes
284+
expected_samples = (gibbs.niter + 1) // gibbs.g
285+
assert gibbs.indicator.shape == (expected_samples, len(times)), "Indicator shape should be correct"
286+
assert gibbs.mcweights.shape == (expected_samples, gibbs.ncomp), "mcweights shape should be correct"
287+
assert gibbs.mcrates.shape == (expected_samples, gibbs.ncomp), "mcrates shape should be correct"
288+
289+
# Check that hyperparameters were initialized
290+
assert hasattr(gibbs, 'whypers'), "Should have weight hyperparameters"
291+
assert hasattr(gibbs, 'rhypers'), "Should have rate hyperparameters"
292+
assert gibbs.whypers.shape == (gibbs.ncomp,), "Weight hyperparameters should have correct shape"
293+
assert gibbs.rhypers.shape == (gibbs.ncomp, 2), "Rate hyperparameters should have correct shape"
294+
295+
296+
def test_parallel_gibbs_initialization(self, mock_contact_file):
297+
"""Test ParallelGibbs initialization parameters."""
298+
parallel_gibbs = ParallelGibbs(
299+
contacts=mock_contact_file,
300+
nproc=4,
301+
ncomp=5,
302+
niter=50000
303+
)
304+
305+
assert parallel_gibbs.contacts == mock_contact_file, "Contacts file should be stored"
306+
assert parallel_gibbs.nproc == 4, "Number of processes should be stored"
307+
assert parallel_gibbs.ncomp == 5, "Number of components should be stored"
308+
assert parallel_gibbs.niter == 50000, "Number of iterations should be stored"
309+
assert parallel_gibbs.cutoff == 7.0, "Cutoff should be extracted from filename"
310+
311+
312+
def test_gibbs_sampler_old_style_input(self, tmp_path, synthetic_timeseries):
16313
"""Test Gibbs sampler with old-style (non-combined) input data."""
17314

18-
# Create simple synthetic test data for more reliable testing
19-
rng = np.random.default_rng(seed=42)
20-
n_samples = 100
21-
# Create a simple bimodal distribution
22-
times_short = rng.exponential(0.5, n_samples // 2) # Fast component
23-
times_long = rng.exponential(5.0, n_samples // 2) # Slow component
24-
test_times = np.concatenate([times_short, times_long])
25-
rng.shuffle(test_times) # Mix them up
315+
# Use synthetic timeseries from fixture
316+
test_times = synthetic_timeseries['times']
26317

27318
# Create temporary directory for output
28319
output_dir = tmp_path / "basicrta-7.0" / "test_residue"
29320
output_dir.mkdir(parents=True, exist_ok=True)
30321

31322
# Change to tmp_path to avoid creating output in the repo
32-
original_cwd = os.getcwd()
33-
os.chdir(tmp_path)
34-
35-
try:
323+
with work_in(tmp_path):
36324
# Run Gibbs sampler with old-style input
37325
gibbs = Gibbs(
38326
times=test_times,
@@ -69,8 +357,4 @@ def test_gibbs_sampler_old_style_input(self, tmp_path):
69357
err_msg="Weights should sum to 1 for each sample")
70358

71359
# Verify rates are positive
72-
assert np.all(gibbs.mcrates > 0), "All rates should be positive"
73-
74-
finally:
75-
# Restore original working directory
76-
os.chdir(original_cwd)
360+
assert np.all(gibbs.mcrates > 0), "All rates should be positive"

0 commit comments

Comments
 (0)