11"""
2- Tests for the Gibbs sampler with old-style input .
2+ Tests for the Gibbs sampler with comprehensive coverage .
33"""
44
55import pytest
66import numpy as np
77import tempfile
88import os
9- from basicrta .gibbs import Gibbs
9+ import pickle
10+ from unittest .mock import Mock , patch
11+ from basicrta .gibbs import Gibbs , ParallelGibbs
12+ from basicrta .tests .utils import work_in
13+ import MDAnalysis as mda
14+
15+
@pytest.fixture
def synthetic_timeseries():
    """
    Build a reproducible, shuffled sample of bi-exponential residence times.

    Half of the samples are drawn from an exponential distribution with
    scale 0.5 and the other half from one with scale 5.0; the two halves are
    concatenated and shuffled with the same seeded generator.

    Returns
    -------
    dict
        ``'times'`` holds the shuffled samples; the remaining keys record
        the number of samples and the approximate component count, rates
        and weights used to generate them.
    """
    n_samples = 200
    per_component = n_samples // 2

    # A fixed seed keeps the generated data identical between test runs.
    rng = np.random.default_rng(seed=42)

    # Draw order matters for reproducibility: fast component first
    # (scale 0.5, i.e. rate ~2.0), then the slow one (scale 5.0, rate ~0.2).
    fast_times = rng.exponential(0.5, per_component)
    slow_times = rng.exponential(5.0, per_component)

    pooled_times = np.concatenate([fast_times, slow_times])
    rng.shuffle(pooled_times)  # interleave the two components in place

    summary = {}
    summary['times'] = pooled_times
    summary['expected_components'] = 2
    summary['expected_rates_approx'] = [2.0, 0.2]
    summary['expected_weights_approx'] = [0.5, 0.5]
    summary['n_samples'] = n_samples
    return summary
45+
46+
@pytest.fixture
def mock_contact_file(tmp_path, synthetic_timeseries):
    """
    Create a mock contact pickle file for testing ParallelGibbs.

    Parameters
    ----------
    tmp_path : Path
        Pytest temporary directory fixture
    synthetic_timeseries : dict
        Synthetic timeseries data from fixture; its ``'times'`` entry is
        used as the residence-time column

    Returns
    -------
    str
        Path to the created contact file (``contacts_7.0.pkl`` inside
        ``tmp_path``)
    """
    from basicrta.tests.utils import make_Universe

    # Two small test universes carrying resnames and resids attributes:
    # 50 atoms / 10 residues for ag1 and 100 atoms / 20 residues for ag2.
    ag1_universe = make_Universe(extras=('resnames', 'resids'),
                                 size=(50, 10, 1))
    ag2_universe = make_Universe(extras=('resnames', 'resids'),
                                 size=(100, 20, 1))

    ag1 = ag1_universe.atoms[:50]    # all atoms of the first universe
    ag2 = ag2_universe.atoms[:100]   # all atoms of the second universe

    # Assign through the attribute setter.  Writing into a slice of the
    # array returned by ``ag1.residues.resnames`` / ``.resids`` only changes
    # that returned array and leaves the topology untouched, so the test
    # labels would silently never be applied.
    ag1.residues.resnames = ['TRP', 'VAL', 'ALA', 'GLY', 'PHE'] * 2
    ag1.residues.resids = [313, 314, 315, 316, 317, 318, 319, 320, 321, 322]

    times = synthetic_timeseries['times']
    n_contacts = len(times)

    # Column layout:
    # [protein_resid, lipid_resid, frame, residence_time, contact_number]
    contacts = np.zeros((n_contacts, 5))
    contacts[:, 0] = 313                     # every contact is with residue 313
    contacts[:, 1] = np.arange(n_contacts)   # a different lipid residue each
    contacts[:, 2] = np.arange(n_contacts)   # sequential frames
    contacts[:, 3] = times                   # residence times
    contacts[:, 4] = np.arange(n_contacts)   # contact numbers

    metadata = {
        'ag1': ag1,
        'ag2': ag2,
        'ts': 0.1,  # timestep
        'top': 'test.pdb',
        'traj': 'test.xtc',
    }

    # The metadata travels with the array by being attached to its dtype.
    contacts_dtype = np.dtype(contacts.dtype, metadata=metadata)
    contacts_array = contacts.astype(contacts_dtype)

    contact_file = tmp_path / "contacts_7.0.pkl"
    with open(contact_file, 'wb') as f:
        pickle.dump(contacts_array, f)

    return str(contact_file)
10109
11110
12111class TestGibbsSampler :
13- """Tests for Gibbs sampler with traditional input ."""
112+ """Comprehensive tests for both Gibbs sampler classes ."""
14113
15- def test_gibbs_sampler_old_style_input (self , tmp_path ):
@pytest.mark.parametrize("init_kwargs", [
    {
        'times': None,  # Will be set from fixture
        'residue': 'W313',
        'ncomp': 2,
        'niter': 1000,
        'burnin': 5,
        'cutoff': 7.0,
        'g': 100
    },
])
def test_gibbs_run_method(self, tmp_path, synthetic_timeseries, init_kwargs):
    """Test the run() method for Gibbs class with synthetic data.

    Parameters
    ----------
    tmp_path : Path
        Pytest temporary directory; the sampler is run from inside it
    synthetic_timeseries : dict
        Fixture providing the residence times under ``'times'``
    init_kwargs : dict
        Keyword arguments for ``Gibbs``; a ``'times'`` placeholder is
        replaced by the fixture data
    """
    # Work on a copy: the dict object belongs to the parametrize decorator
    # and is shared, so it must not be modified by a test invocation.
    sampler_kwargs = dict(init_kwargs)
    if 'times' in sampler_kwargs:
        sampler_kwargs['times'] = synthetic_timeseries['times']

    # Create output directory structure
    output_dir = (tmp_path / f"basicrta-{sampler_kwargs['cutoff']}"
                  / sampler_kwargs['residue'])
    output_dir.mkdir(parents=True, exist_ok=True)

    # Change to tmp_path to avoid creating output in the repo
    with work_in(tmp_path):
        gibbs = Gibbs(**sampler_kwargs)

        # Test that initialization worked correctly
        assert gibbs.times is not None, "Times should be set"
        assert gibbs.ncomp == 2, "Should have 2 components"
        assert gibbs.niter == 1000, "Should have 1000 iterations"

        gibbs.run()

        # Verify that the sampler ran successfully
        assert hasattr(gibbs, 'mcweights'), "Gibbs sampler should have mcweights after running"
        assert hasattr(gibbs, 'mcrates'), "Gibbs sampler should have mcrates after running"
        assert gibbs.mcweights is not None, "mcweights should not be None"
        assert gibbs.mcrates is not None, "mcrates should not be None"

        # One sample is expected for every g iterations
        expected_samples = (gibbs.niter + 1) // gibbs.g
        assert gibbs.mcweights.shape[0] == expected_samples, f"Expected {expected_samples} weight samples"
        assert gibbs.mcrates.shape[0] == expected_samples, f"Expected {expected_samples} rate samples"

        # Check dimensions match number of components
        assert gibbs.mcweights.shape[1] == gibbs.ncomp, f"Should have {gibbs.ncomp} weight components"
        assert gibbs.mcrates.shape[1] == gibbs.ncomp, f"Should have {gibbs.ncomp} rate components"

        # Verify weights are properly normalized (sum to ~1 for each sample)
        weight_sums = np.sum(gibbs.mcweights, axis=1)
        np.testing.assert_allclose(weight_sums, 1.0, rtol=1e-10,
                                   err_msg="Weights should sum to 1 for each sample")

        # Verify rates are positive
        assert np.all(gibbs.mcrates > 0), "All rates should be positive"

        # Check that survival function was computed
        assert hasattr(gibbs, 't'), "Should have time points for survival function"
        assert hasattr(gibbs, 's'), "Should have survival function values"
        assert len(gibbs.t) > 0, "Time points should not be empty"
        assert len(gibbs.s) > 0, "Survival function should not be empty"

        # Check that indicator variables were stored
        assert hasattr(gibbs, 'indicator'), "Should have indicator variables"
        assert gibbs.indicator is not None, "Indicator should not be None"
        assert gibbs.indicator.shape[0] == expected_samples, "Indicator should match sample count"
        assert gibbs.indicator.shape[1] == len(gibbs.times), "Indicator should match data count"
183+
184+
def test_parallel_gibbs_initialization_and_contact_loading(self, tmp_path, mock_contact_file, synthetic_timeseries):
    """Check ParallelGibbs construction and the layout of the contact file.

    The sampler itself is not run: the test only verifies the attributes
    set at construction and then re-reads the pickle written by the
    ``mock_contact_file`` fixture to validate its contents.
    """
    reference_times = synthetic_timeseries['times']

    # Run from tmp_path so that any output lands outside the repository.
    with work_in(tmp_path):
        parallel_gibbs = ParallelGibbs(
            contacts=mock_contact_file,
            nproc=2,  # 2 cores are always available ...
            ncomp=4,
            niter=1000,
        )

        # Constructor arguments must be stored unchanged; the cutoff is
        # expected to come from the "contacts_7.0.pkl" file name.
        expected_attributes = {
            'contacts': mock_contact_file,
            'nproc': 2,
            'ncomp': 4,
            'niter': 1000,
            'cutoff': 7.0,
        }
        for attribute, expected in expected_attributes.items():
            assert getattr(parallel_gibbs, attribute) == expected

        # The file was written by the fixture in this test session, so
        # unpickling it here is safe.
        with open(mock_contact_file, 'rb') as handle:
            contacts = pickle.load(handle)

        # Structure of the contact array
        n_rows, n_columns = len(contacts), contacts.shape[1]
        assert n_columns == 5, "Contact data should have 5 columns"
        assert n_rows == len(reference_times), "Should have correct number of contacts"

        # Metadata attached to the dtype (no Gibbs run required)
        metadata = contacts.dtype.metadata
        required_entries = (
            ('ag1', "Metadata should contain ag1"),
            ('ag2', "Metadata should contain ag2"),
            ('ts', "Metadata should contain timestep"),
        )
        for key, message in required_entries:
            assert key in metadata, message

        protids = np.unique(contacts[:, 0])
        assert 313 in protids, "Should have protein residue 313"

        # Residence times of residue 313 must round-trip unchanged.
        is_residue_313 = contacts[:, 0] == 313
        residue_313_times = contacts[is_residue_313][:, 3]
        assert len(residue_313_times) > 0, "Should have residence times for residue 313"
        assert np.allclose(residue_313_times, reference_times), "Residence times should match synthetic data"
226+
227+
def test_gibbs_initialization_parameters(self, synthetic_timeseries):
    """Test that Gibbs class initializes with correct parameters.

    Every constructor argument is passed a non-trivial value and then
    compared against the attribute of the same name on the instance.
    """
    times = synthetic_timeseries['times']

    gibbs = Gibbs(
        times=times,
        residue='TEST123',
        loc=0,
        ncomp=3,
        niter=5000,
        cutoff=8.5,
        g=25,
        burnin=500,
        gskip=5
    )

    # Test all initialization parameters
    assert np.array_equal(gibbs.times, times), "Times should be stored correctly"
    assert gibbs.residue == 'TEST123', "Residue should be stored correctly"
    assert gibbs.loc == 0, "Location should be stored correctly"
    assert gibbs.ncomp == 3, "Number of components should be stored correctly"
    assert gibbs.niter == 5000, "Number of iterations should be stored correctly"
    assert gibbs.cutoff == 8.5, "Cutoff should be stored correctly"
    # g and gskip are distinct parameters; each message names its own
    # parameter so a failure identifies which one is wrong.
    assert gibbs.g == 25, "g should be stored correctly"
    assert gibbs.burnin == 500, "Burnin should be stored correctly"
    assert gibbs.gskip == 5, "gskip should be stored correctly"

    # Test that timestep is computed correctly
    assert gibbs.ts is not None, "Timestep should be computed"
    assert gibbs.ts > 0, "Timestep should be positive"
258+
259+
def test_gibbs_prepare_method(self, synthetic_timeseries):
    """Verify the state that ``Gibbs._prepare()`` sets up before sampling.

    Checks the survival function (``t``/``s``), the shapes of the sample
    storage arrays and the shapes of the hyperparameter arrays.
    """
    times = synthetic_timeseries['times']
    n_times = len(times)

    gibbs = Gibbs(times=times, residue='W313', ncomp=2, niter=1000, g=100)
    gibbs._prepare()

    # Survival function: present, non-empty and bounded to [0, 1]
    assert hasattr(gibbs, 't'), "Should have time points"
    assert hasattr(gibbs, 's'), "Should have survival function"
    assert len(gibbs.t) > 0, "Time points should not be empty"
    assert len(gibbs.s) > 0, "Survival function should not be empty"
    assert np.all(gibbs.s >= 0), "Survival function should be non-negative"
    assert np.all(gibbs.s <= 1), "Survival function should be <= 1"

    # Storage arrays: one row per stored sample
    n_stored = (gibbs.niter + 1) // gibbs.g
    per_component_shape = (n_stored, gibbs.ncomp)
    assert gibbs.indicator.shape == (n_stored, n_times), "Indicator shape should be correct"
    assert gibbs.mcweights.shape == per_component_shape, "mcweights shape should be correct"
    assert gibbs.mcrates.shape == per_component_shape, "mcrates shape should be correct"

    # Hyperparameters: one weight entry and two rate entries per component
    assert hasattr(gibbs, 'whypers'), "Should have weight hyperparameters"
    assert hasattr(gibbs, 'rhypers'), "Should have rate hyperparameters"
    assert gibbs.whypers.shape == (gibbs.ncomp,), "Weight hyperparameters should have correct shape"
    assert gibbs.rhypers.shape == (gibbs.ncomp, 2), "Rate hyperparameters should have correct shape"
294+
295+
def test_parallel_gibbs_initialization(self, mock_contact_file):
    """Check that ParallelGibbs keeps the arguments it was constructed with."""
    settings = dict(contacts=mock_contact_file, nproc=4, ncomp=5, niter=50000)
    parallel_gibbs = ParallelGibbs(**settings)

    assert parallel_gibbs.contacts == settings['contacts'], "Contacts file should be stored"
    assert parallel_gibbs.nproc == settings['nproc'], "Number of processes should be stored"
    assert parallel_gibbs.ncomp == settings['ncomp'], "Number of components should be stored"
    assert parallel_gibbs.niter == settings['niter'], "Number of iterations should be stored"

    # The fixture file is named "contacts_7.0.pkl".
    assert parallel_gibbs.cutoff == 7.0, "Cutoff should be extracted from filename"
310+
311+
312+ def test_gibbs_sampler_old_style_input (self , tmp_path , synthetic_timeseries ):
16313 """Test Gibbs sampler with old-style (non-combined) input data."""
17314
18- # Create simple synthetic test data for more reliable testing
19- rng = np .random .default_rng (seed = 42 )
20- n_samples = 100
21- # Create a simple bimodal distribution
22- times_short = rng .exponential (0.5 , n_samples // 2 ) # Fast component
23- times_long = rng .exponential (5.0 , n_samples // 2 ) # Slow component
24- test_times = np .concatenate ([times_short , times_long ])
25- rng .shuffle (test_times ) # Mix them up
315+ # Use synthetic timeseries from fixture
316+ test_times = synthetic_timeseries ['times' ]
26317
27318 # Create temporary directory for output
28319 output_dir = tmp_path / "basicrta-7.0" / "test_residue"
29320 output_dir .mkdir (parents = True , exist_ok = True )
30321
31322 # Change to tmp_path to avoid creating output in the repo
32- original_cwd = os .getcwd ()
33- os .chdir (tmp_path )
34-
35- try :
323+ with work_in (tmp_path ):
36324 # Run Gibbs sampler with old-style input
37325 gibbs = Gibbs (
38326 times = test_times ,
@@ -69,8 +357,4 @@ def test_gibbs_sampler_old_style_input(self, tmp_path):
69357 err_msg = "Weights should sum to 1 for each sample" )
70358
71359 # Verify rates are positive
72- assert np .all (gibbs .mcrates > 0 ), "All rates should be positive"
73-
74- finally :
75- # Restore original working directory
76- os .chdir (original_cwd )
360+ assert np .all (gibbs .mcrates > 0 ), "All rates should be positive"
0 commit comments