Skip to content

Commit 4cc7749

Browse files
Copilotorbeckst
andcommitted
Address PR review feedback: remove kinetic clustering references and use pytest fixtures
Co-authored-by: orbeckst <237980+orbeckst@users.noreply.github.com>
1 parent bd367fb commit 4cc7749

File tree

2 files changed

+44
-45
lines changed

2 files changed

+44
-45
lines changed

basicrta/gibbs.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ def run(self, run_resids=None, g=100):
5959
print(f"WARNING: Using combined contact file with {metadata['n_trajectories']} trajectories.")
6060
print("WARNING: Kinetic clustering is not yet supported for combined contacts.")
6161
print("WARNING: The Gibbs sampler will pool all residence times together.")
62-
print("WARNING: Trajectory source information is available but not used in kinetic clustering.")
6362

6463
protids = np.unique(contacts[:, 0])
6564
if not run_resids:
@@ -238,14 +237,8 @@ def cluster(self, method="GaussianMixture", **kwargs):
238237
"""
239238
# Check if this Gibbs result was created from combined contact data
240239
if hasattr(self, '_from_combined_contacts') and self._from_combined_contacts:
241-
raise NotImplementedError(
242-
"Kinetic clustering is not yet supported for combined contact data. "
243-
"The trajectory source information needed for proper kinetic clustering "
244-
"is available in the combined contact files but not yet utilized in the "
245-
"clustering algorithm. For now, analyze each trajectory separately for "
246-
"kinetic clustering, or use the combined data only for residence time "
247-
"distribution analysis."
248-
)
240+
print("INFO: Using combined contact data for clustering. "
241+
"Trajectory source information is pooled together.")
249242

250243
from sklearn import mixture
251244
from scipy import stats

basicrta/tests/test_combine_contacts.py

Lines changed: 42 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -24,21 +24,21 @@ def __init__(self, resids):
2424
self.resids = np.array(resids)
2525

2626

27-
class TestCombineContacts:
28-
"""Test class for CombineContacts functionality."""
29-
30-
def setup_method(self):
31-
"""Set up test fixtures."""
32-
self.temp_dir = tempfile.mkdtemp()
33-
self.original_dir = os.getcwd()
34-
os.chdir(self.temp_dir)
35-
36-
def teardown_method(self):
37-
"""Clean up after tests."""
38-
os.chdir(self.original_dir)
39-
shutil.rmtree(self.temp_dir)
40-
41-
def create_mock_contacts(self, filename, n_contacts=100, cutoff=7.0,
27+
@pytest.fixture
28+
def temp_dir():
29+
"""Create a temporary directory for tests."""
30+
temp_dir = tempfile.mkdtemp()
31+
original_dir = os.getcwd()
32+
os.chdir(temp_dir)
33+
yield temp_dir
34+
os.chdir(original_dir)
35+
shutil.rmtree(temp_dir)
36+
37+
38+
@pytest.fixture
39+
def create_mock_contacts():
40+
"""Factory fixture for creating mock contact files."""
41+
def _create_mock_contacts(filename, n_contacts=100, cutoff=7.0,
4242
ts=0.1, traj_name="test.xtc", top_name="test.pdb"):
4343
"""Create a mock contact file for testing."""
4444
# Create mock atom groups that can be pickled
@@ -72,12 +72,18 @@ def create_mock_contacts(self, filename, n_contacts=100, cutoff=7.0,
7272
pickle.dump(contacts, f, protocol=5)
7373

7474
return contacts, metadata
75+
76+
return _create_mock_contacts
77+
78+
79+
class TestCombineContacts:
80+
"""Test class for CombineContacts functionality."""
7581

76-
def test_combine_contacts_basic(self):
82+
def test_combine_contacts_basic(self, temp_dir, create_mock_contacts):
7783
"""Test basic contact combination functionality."""
7884
# Create two mock contact files
79-
contacts1, meta1 = self.create_mock_contacts("contacts1.pkl", n_contacts=50)
80-
contacts2, meta2 = self.create_mock_contacts("contacts2.pkl", n_contacts=75,
85+
contacts1, meta1 = create_mock_contacts("contacts1.pkl", n_contacts=50)
86+
contacts2, meta2 = create_mock_contacts("contacts2.pkl", n_contacts=75,
8187
traj_name="test2.xtc")
8288

8389
# Combine them
@@ -110,10 +116,10 @@ def test_combine_contacts_basic(self):
110116
assert np.all(traj_sources[:50] == 0) # First 50 from file 0
111117
assert np.all(traj_sources[50:] == 1) # Next 75 from file 1
112118

113-
def test_incompatible_cutoffs(self):
119+
def test_incompatible_cutoffs(self, temp_dir, create_mock_contacts):
114120
"""Test that incompatible cutoffs raise an error."""
115-
self.create_mock_contacts("contacts1.pkl", cutoff=7.0)
116-
self.create_mock_contacts("contacts2.pkl", cutoff=8.0) # Different cutoff
121+
create_mock_contacts("contacts1.pkl", cutoff=7.0)
122+
create_mock_contacts("contacts2.pkl", cutoff=8.0) # Different cutoff
117123

118124
combiner = CombineContacts(
119125
contact_files=["contacts1.pkl", "contacts2.pkl"]
@@ -122,10 +128,10 @@ def test_incompatible_cutoffs(self):
122128
with pytest.raises(ValueError, match="Incompatible cutoffs"):
123129
combiner.run()
124130

125-
def test_incompatible_atom_groups(self):
131+
def test_incompatible_atom_groups(self, temp_dir, create_mock_contacts):
126132
"""Test that incompatible atom groups raise an error."""
127133
# Create first file with standard residues
128-
contacts1, _ = self.create_mock_contacts("contacts1.pkl")
134+
contacts1, _ = create_mock_contacts("contacts1.pkl")
129135

130136
# Create second file with different protein residues
131137
mock_ag1 = MockAtomGroup([10, 20, 30]) # Different resids
@@ -153,10 +159,10 @@ def test_incompatible_atom_groups(self):
153159
with pytest.raises(ValueError, match="Incompatible ag1 residues"):
154160
combiner.run()
155161

156-
def test_different_timesteps_warning(self, capsys):
162+
def test_different_timesteps_warning(self, temp_dir, create_mock_contacts, capsys):
157163
"""Test that different timesteps produce a warning."""
158-
self.create_mock_contacts("contacts1.pkl", ts=0.1)
159-
self.create_mock_contacts("contacts2.pkl", ts=0.2) # Different timestep
164+
create_mock_contacts("contacts1.pkl", ts=0.1)
165+
create_mock_contacts("contacts2.pkl", ts=0.2) # Different timestep
160166

161167
combiner = CombineContacts(
162168
contact_files=["contacts1.pkl", "contacts2.pkl"]
@@ -168,16 +174,16 @@ def test_different_timesteps_warning(self, capsys):
168174
captured = capsys.readouterr()
169175
assert "WARNING: Different timesteps detected" in captured.out
170176

171-
def test_minimum_files_required(self):
177+
def test_minimum_files_required(self, temp_dir, create_mock_contacts):
172178
"""Test that at least 2 files are required."""
173-
self.create_mock_contacts("contacts1.pkl")
179+
create_mock_contacts("contacts1.pkl")
174180

175181
with pytest.raises(ValueError, match="At least 2 contact files are required"):
176182
CombineContacts(contact_files=["contacts1.pkl"])
177183

178-
def test_missing_file(self):
184+
def test_missing_file(self, temp_dir, create_mock_contacts):
179185
"""Test handling of missing contact files."""
180-
self.create_mock_contacts("contacts1.pkl")
186+
create_mock_contacts("contacts1.pkl")
181187

182188
combiner = CombineContacts(
183189
contact_files=["contacts1.pkl", "nonexistent.pkl"]
@@ -186,10 +192,10 @@ def test_missing_file(self):
186192
with pytest.raises(FileNotFoundError, match="Contact file not found"):
187193
combiner.run()
188194

189-
def test_skip_validation(self):
195+
def test_skip_validation(self, temp_dir, create_mock_contacts):
190196
"""Test skipping compatibility validation."""
191-
self.create_mock_contacts("contacts1.pkl", cutoff=7.0)
192-
self.create_mock_contacts("contacts2.pkl", cutoff=8.0) # Different cutoff
197+
create_mock_contacts("contacts1.pkl", cutoff=7.0)
198+
create_mock_contacts("contacts2.pkl", cutoff=8.0) # Different cutoff
193199

194200
combiner = CombineContacts(
195201
contact_files=["contacts1.pkl", "contacts2.pkl"],
@@ -200,11 +206,11 @@ def test_skip_validation(self):
200206
output_file = combiner.run()
201207
assert os.path.exists(output_file)
202208

203-
def test_combined_contacts_detection(self):
209+
def test_combined_contacts_detection(self, temp_dir, create_mock_contacts):
204210
"""Test that combined contact files are properly detected."""
205211
# Create and combine contacts
206-
self.create_mock_contacts("contacts1.pkl", n_contacts=30)
207-
self.create_mock_contacts("contacts2.pkl", n_contacts=40, traj_name="test2.xtc")
212+
create_mock_contacts("contacts1.pkl", n_contacts=30)
213+
create_mock_contacts("contacts2.pkl", n_contacts=40, traj_name="test2.xtc")
208214

209215
combiner = CombineContacts(
210216
contact_files=["contacts1.pkl", "contacts2.pkl"],

0 commit comments

Comments
 (0)