Skip to content

Commit 77b364e

Browse files
authored
Merge pull request #49 from Becksteinlab/fix-gskip-processing
fix handling of gskip in cluster-py --gskip
2 parents a41b315 + 79798f0 commit 77b364e

File tree

6 files changed

+111
-19
lines changed

6 files changed

+111
-19
lines changed

CHANGELOG.md

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,25 @@ The rules for this file:
1515
* YYYY-MM-DD date format (following ISO 8601)
1616
* accompany each entry with github issue/PR number (Issue #xyz)
1717
-->
18+
## [1.1.3] - UNRELEASED
19+
20+
### Authors
21+
* @orbeckst
22+
* @rjoshi44
23+
24+
### Fixed
25+
* Fixed setting of gskip in ProcessProtein/cluster.py command line interface:
26+
set the default to 100 (as in the paper) and ensure that the correct value
27+
is used as Gibbs.gskip (which is relative to the save skip step of Gibbs.g)
28+
(Issue #48)
29+
30+
### Changed
31+
* Default kwargs for the skipping in the Gibbs sampler are now
32+
gibbs.Gibbs(g=100, gskip=1) (used to be g=50, gskip=2) but for most users
33+
gskip for processing data is not important and it makes more sense to focus
34+
on g as the stride at which we sample AND process data (#48, PR #49)
35+
36+
1837
## [1.1.2] - 2025-07-22
1938

2039
### Authors

basicrta/cluster.py

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import gc
3+
import warnings
34
import numpy as np
45
from tqdm import tqdm
56
from multiprocessing import Pool, Lock
@@ -29,17 +30,21 @@ class ProcessProtein(object):
2930
:param cutoff: Cutoff used in contact analysis.
3031
:type cutoff: float
3132
:param gskip: Gibbs skip parameter for decorrelated samples;
33+
only save every `gskip` samples from full Gibbs sampler chain;
3234
default from https://pubs.acs.org/doi/10.1021/acs.jctc.4c01522
35+
When the sampled Markov chain is loaded, then the output is already
36+
saved at every `Gibbs.g` samples. We calculate a new `gskip` value to
37+
get close to the desired `gskip` value.
3338
:type gskip: int
34-
:param burnin: Burn-in parameter, drop first N samples as equilibration;
39+
:param burnin: Burn-in parameter, drop first `burnin` samples as equilibration;
3540
default from https://pubs.acs.org/doi/10.1021/acs.jctc.4c01522
3641
:type burnin: int
3742
"""
3843

3944
def __init__(self, niter, prot, cutoff,
40-
gskip=1000, burnin=10000,
45+
gskip=100, burnin=10000,
4146
taus=None, bars=None):
42-
self.residues = Results() # TODO: double-check that we need to use this, it gets set in reprocess/get_taus
47+
self.residues = None
4348
self.niter = niter
4449
self.prot = prot
4550
self.cutoff = cutoff
@@ -55,16 +60,23 @@ def _single_residue(self, adir, process=False):
5560
if os.path.exists(f'{adir}/gibbs_{self.niter}.pkl'):
5661
result = f'{adir}/gibbs_{self.niter}.pkl'
5762
try:
58-
result = f'{adir}/gibbs_{self.niter}.pkl'
5963
g = Gibbs().load(result)
64+
except:
65+
result = None
66+
tau = [0, 0, 0]
67+
else:
6068
if process:
61-
g.gskip = self.gskip
69+
# calculate the new g.gskip value:
70+
ggskip = self.gskip // g.g
71+
if ggskip < 1:
72+
ggskip = 1
73+
warnings.warn(f"WARNING: gskip={self.gskip} is less than g={g.g}, setting gskip to 1")
74+
# NOTE: Gibbs samples are saved every g.g steps, then sub-sampled by g.gskip
75+
# Total skip interval = g.g * g.gskip, giving niter // (g.g * g.gskip) independent samples
76+
g.gskip = ggskip # process every g.g * g.gskip samples from full chain
6277
g.burnin = self.burnin
6378
g.process_gibbs()
6479
tau = g.estimate_tau()
65-
except:
66-
result = None
67-
tau = [0, 0, 0]
6880
else:
6981
result = None
7082
tau = [0, 0, 0]
@@ -228,7 +240,7 @@ def b_color_structure(self, structure):
228240
'LABEL-CUTOFF * <tau>. ')
229241
parser.add_argument('--structure', type=str, nargs='?')
230242
# use for default values
231-
parser.add_argument('--gskip', type=int, default=1000,
243+
parser.add_argument('--gskip', type=int, default=100,
232244
help='Gibbs skip parameter for decorrelated samples;'
233245
'default from https://pubs.acs.org/doi/10.1021/acs.jctc.4c01522')
234246
parser.add_argument('--burnin', type=int, default=10000,
@@ -240,6 +252,5 @@ def b_color_structure(self, structure):
240252
pp = ProcessProtein(args.niter, args.prot, args.cutoff,
241253
gskip=args.gskip, burnin=args.burnin)
242254
pp.reprocess(nproc=args.nproc)
243-
pp.get_taus()
244255
pp.write_data()
245256
pp.plot_protein(label_cutoff=args.label_cutoff)

basicrta/gibbs.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,23 @@ class Gibbs(object):
116116
directory to load/save results. Allows for multiple cutoffs
117117
to be tested in directory containing contacts.
118118
:type cutoff: float
119+
:param g: Gibbs skip parameter for decorrelated samples;
120+
only save every `g` samples from full Gibbs sampler chain;
121+
default from https://pubs.acs.org/doi/10.1021/acs.jctc.4c01522
122+
(NOTE: this value is called *gskip* in cluster.py)
123+
:type g: int
124+
:param burnin: Burn-in parameter, drop first `burnin` samples as equilibration;
125+
default from https://pubs.acs.org/doi/10.1021/acs.jctc.4c01522
126+
:type burnin: int
127+
:param gskip: Process data from the subsampled chain (every `g` samples) at a
128+
coarser skip interval of `gskip` samples. Thus, in total, samples
129+
are taken at ``g * gskip`` steps from the full chain.
130+
(This is useful for sensitivity analysis where we run the chain with
131+
a small `g` value and save many samples and then use `gskip` to process
132+
samples at increasingly larger intervals without having to re-run the
133+
chain.) The default value of 1 means that the samples are processed at
134+
every `g` samples from the full chain.
135+
:type gskip: int
119136
120137
EXAMPLE
121138
-------
@@ -139,7 +156,7 @@ class Gibbs(object):
139156
"""
140157

141158
def __init__(self, times=None, residue=None, loc=0, ncomp=15, niter=110000,
142-
cutoff=None, g=50, burnin=10000, gskip=2):
159+
cutoff=None, g=100, burnin=10000, gskip=1):
143160
self.times = times
144161
self.residue = residue
145162
self.niter = niter

basicrta/tests/test_cluster.py

Lines changed: 41 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pytest
88
import numpy as np
99
import sys
10+
import warnings
1011
from unittest.mock import patch, MagicMock, call
1112
from basicrta.cluster import ProcessProtein
1213
from basicrta.tests.utils import work_in
@@ -22,10 +23,11 @@ def test_init_with_default_values(self):
2223
assert pp.niter == 110000
2324
assert pp.prot == "test_protein"
2425
assert pp.cutoff == 7.0
25-
assert pp.gskip == 1000 # Default value from paper
26+
assert pp.gskip == 100 # Default value from paper
2627
assert pp.burnin == 10000 # Default value from paper
2728
assert pp.taus is None
2829
assert pp.bars is None
30+
assert pp.residues is None
2931

3032
def test_init_with_custom_values(self):
3133
"""Test initialization with custom gskip and burnin values."""
@@ -50,7 +52,7 @@ def test_getitem_method(self):
5052
assert pp["niter"] == 110000
5153
assert pp["prot"] == "test_protein"
5254
assert pp["cutoff"] == 7.0
53-
assert pp["gskip"] == 1000
55+
assert pp["gskip"] == 100
5456
assert pp["burnin"] == 10000
5557

5658
def test_single_residue_missing_file(self, tmp_path):
@@ -69,8 +71,9 @@ def test_single_residue_missing_file(self, tmp_path):
6971
assert tau == [0, 0, 0]
7072
assert result is None
7173

74+
@pytest.mark.parametrize("gskip", [111, 100, 50, 10])
7275
@patch('basicrta.cluster.Gibbs')
73-
def test_single_residue_with_file(self, mock_gibbs, tmp_path):
76+
def test_single_residue_with_file(self, mock_gibbs, tmp_path, gskip):
7477
"""Test _single_residue method when gibbs file exists."""
7578
# Create a mock directory structure
7679
residue_dir = tmp_path / "basicrta-7.0" / "R123"
@@ -83,23 +86,53 @@ def test_single_residue_with_file(self, mock_gibbs, tmp_path):
8386
# Configure the mock
8487
mock_gibbs_instance = MagicMock()
8588
mock_gibbs_instance.estimate_tau.return_value = [0.1, 1.5, 3.0]
89+
mock_gibbs_instance.g = 50
8690
mock_gibbs.return_value.load.return_value = mock_gibbs_instance
8791

88-
pp = ProcessProtein(niter=110000, prot="test_protein", cutoff=7.0)
92+
pp = ProcessProtein(niter=110000, prot="test_protein", gskip=gskip, cutoff=7.0)
8993

9094
# Call the method with processing enabled
91-
residue, tau, result = pp._single_residue(str(residue_dir), process=True)
95+
with warnings.catch_warnings():
96+
warnings.simplefilter("ignore", UserWarning)
97+
residue, tau, result = pp._single_residue(str(residue_dir), process=True)
9298

93-
# Verify results
9499
assert residue == "R123"
95100
assert tau == [0.1, 1.5, 3.0]
96101
assert result == str(gibbs_file)
97102

98-
# Verify the Gibbs object was configured correctly
99-
assert mock_gibbs_instance.gskip == 1000
103+
# Verify the Gibbs object was re-configured correctly
104+
ggskip = gskip // mock_gibbs_instance.g
105+
if ggskip < 1:
106+
ggskip = 1
107+
assert mock_gibbs_instance.gskip == ggskip
100108
assert mock_gibbs_instance.burnin == 10000
101109
mock_gibbs_instance.process_gibbs.assert_called_once()
102110

111+
@patch('basicrta.cluster.Gibbs')
112+
def test_single_residue_with_file_gskip_warning(self, mock_gibbs, tmp_path):
113+
"""Test _single_residue method warns when gskip is less than g."""
114+
# Create a mock directory structure
115+
residue_dir = tmp_path / "basicrta-7.0" / "R123"
116+
residue_dir.mkdir(parents=True)
117+
118+
# Create a mock gibbs pickle file
119+
gibbs_file = residue_dir / "gibbs_110000.pkl"
120+
gibbs_file.touch()
121+
122+
# Configure the mock
123+
mock_gibbs_instance = MagicMock()
124+
mock_gibbs_instance.estimate_tau.return_value = [0.1, 1.5, 3.0]
125+
mock_gibbs_instance.g = 50
126+
mock_gibbs.return_value.load.return_value = mock_gibbs_instance
127+
128+
pp = ProcessProtein(niter=110000, prot="test_protein", gskip=10, cutoff=7.0)
129+
with pytest.warns(UserWarning,
130+
match="WARNING: gskip=10 is less than g=50, setting gskip to 1"):
131+
residue, tau, result = pp._single_residue(str(residue_dir), process=True)
132+
133+
assert pp.gskip == 10
134+
assert mock_gibbs_instance.gskip == 1
135+
103136
@patch('basicrta.cluster.Gibbs')
104137
def test_single_residue_exception_handling(self, mock_gibbs, tmp_path):
105138
"""Test _single_residue method handles exceptions gracefully."""

basicrta/tests/test_combine_contacts.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,8 @@ def test_gibbs_sampler_integration(self, tmp_path, create_mock_contacts):
265265
ncomp=2, # Use 2 components for stability
266266
niter=1000, # 1000 steps as requested
267267
burnin=5, # 5 burnin steps as requested
268+
g=50,
269+
gskip=1,
268270
cutoff=7.0 # Set cutoff for directory creation
269271
)
270272

basicrta/tests/test_gibbs.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,16 @@ class TestGibbsSampler:
127127
'cutoff': 7.0,
128128
'g': 100
129129
},
130+
{
131+
'times': None, # Will be set from fixture
132+
'residue': 'W313',
133+
'ncomp': 2,
134+
'niter': 1000,
135+
'burnin': 5,
136+
'cutoff': 7.0,
137+
'g': 50,
138+
'gskip': 2
139+
},
130140
])
131141
def test_gibbs_run_method(self, tmp_path, synthetic_timeseries, init_kwargs):
132142
"""Test the run() method for Gibbs class with synthetic data."""

0 commit comments

Comments
 (0)