Commit 567ae4b

Copilot and orbeckst authored
Add support for combining contact timeseries from multiple repeat runs (#17)
This PR implements functionality to combine contact timeseries from multiple repeat runs, enabling pooled analysis of binding kinetics data rather than analyzing each run separately.

* Implement CombineContacts class and CLI for combining contact timeseries from multiple repeats
* Introduce an additional entry in the contact list for the trajectory the data came from
* Add tests (new test for Gibbs sampler and tests for combining)
* Add new docs for combining data
* Update CHANGELOG

---------

Co-authored-by: copilot-swe-agent[bot] <[email protected]>
Co-authored-by: Oliver Beckstein <[email protected]>
1 parent 000ebc5 commit 567ae4b

File tree

11 files changed: +738 −4 lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
@@ -3,3 +3,6 @@
 basicrta.egg-info/
 dist/
 
+# Test artifacts
+basicrta-*/
+*.pkl

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
@@ -20,9 +20,15 @@ The rules for this file:
 
 ### Authors
 <!-- GitHub usernames of contributors to this release -->
+* @copilot
+* @orbeckst
 
 ### Added
 <!-- New added features -->
+* Support for combining contact timeseries from multiple repeat runs through new
+  `CombineContacts` class and `python -m basicrta.combine` CLI interface.
+  Enables pooled analysis of binding kinetics data with metadata preservation
+  and trajectory source tracking (Issue #16)
 
 ### Fixed
 <!-- Bug fixes -->

basicrta/combine.py

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
+#!/usr/bin/env python
+
+"""
+Command-line interface for combining contact timeseries from multiple repeat runs.
+
+This module provides functionality to combine contact files from multiple
+trajectory repeats, enabling pooled analysis of binding kinetics.
+"""
+
+import os
+import argparse
+from basicrta.contacts import CombineContacts
+
+
+def main():
+    """Main function for combining contact files."""
+    parser = argparse.ArgumentParser(
+        description="Combine contact timeseries from multiple repeat runs. "
+                    "This enables pooling data from multiple trajectory repeats "
+                    "and calculating posteriors from all data together."
+    )
+
+    parser.add_argument(
+        '--contacts',
+        nargs='+',
+        required=True,
+        help="List of contact pickle files to combine (e.g., contacts_7.0.pkl from different runs)"
+    )
+
+    parser.add_argument(
+        '--output',
+        type=str,
+        default='combined_contacts.pkl',
+        help="Output filename for combined contacts (default: combined_contacts.pkl)"
+    )
+
+    parser.add_argument(
+        '--no-validate',
+        action='store_true',
+        help="Skip compatibility validation (use with caution)"
+    )
+
+    args = parser.parse_args()
+
+    # Validate input files exist
+    missing_files = []
+    for filename in args.contacts:
+        if not os.path.exists(filename):
+            missing_files.append(filename)
+
+    if missing_files:
+        print("ERROR: The following contact files were not found:")
+        for filename in missing_files:
+            print(f"  {filename}")
+        return 1
+
+    if len(args.contacts) < 2:
+        print("ERROR: At least 2 contact files are required for combining")
+        return 1
+
+    if os.path.exists(args.output):
+        print(f"ERROR: Output file {args.output} already exists")
+        return 1
+
+    try:
+        combiner = CombineContacts(
+            contact_files=args.contacts,
+            output_name=args.output,
+            validate_compatibility=not args.no_validate
+        )
+
+        output_file = combiner.run()
+
+        print(f"\nCombination successful!")
+        print(f"Combined contact file saved as: {output_file}")
+        print(f"\nYou can now use this file with the Gibbs sampler:")
+        print(f"  python -m basicrta.gibbs --contacts {output_file} --nproc <N>")
+
+        return 0
+
+    except Exception as e:
+        print(f"ERROR: {e}")
+        return 1
+
+
+if __name__ == '__main__':
    exit(main())

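The entry point above provides the `python -m basicrta.combine` interface referred to in the CHANGELOG. As a minimal usage sketch (the repeat directories `run1/` and `run2/`, each holding a `contacts_7.0.pkl`, are assumed for illustration and are not part of this commit):

    python -m basicrta.combine \
        --contacts run1/contacts_7.0.pkl run2/contacts_7.0.pkl \
        --output combined_contacts.pkl

Passing `--no-validate` would skip the cutoff and atom-group compatibility checks; as the help text notes, use it with caution.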
basicrta/contacts.py

Lines changed: 129 additions & 1 deletion
@@ -193,7 +193,7 @@ def run(self):
         # os.remove('.tmpmap')
         # cfiles = glob.glob('.contacts*')
         # [os.remove(f) for f in cfiles]
-        print(f'\nSaved contacts to "contacts_{self.cutoff}.npy"')
+        print(f'\nSaved contacts to "contacts_{self.cutoff}.pkl"')
 
     def _lipswap(self, lip, memarr, i):
         from basicrta.util import get_dec

@@ -232,6 +232,134 @@ def _lipswap(self, lip, memarr, i):
         return len(dset)
 
 
+class CombineContacts(object):
+    """Class to combine contact timeseries from multiple repeat runs.
+
+    This class enables pooling data from multiple trajectory repeats and
+    calculating posteriors from all data together, rather than analyzing
+    each run separately.
+
+    :param contact_files: List of contact pickle files to combine
+    :type contact_files: list of str
+    :param output_name: Name for the combined output file (default: 'combined_contacts.pkl')
+    :type output_name: str, optional
+    :param validate_compatibility: Whether to validate that files are compatible (default: True)
+    :type validate_compatibility: bool, optional
+    """
+
+    def __init__(self, contact_files, output_name='combined_contacts.pkl',
+                 validate_compatibility=True):
+        self.contact_files = contact_files
+        self.output_name = output_name
+        self.validate_compatibility = validate_compatibility
+
+        if len(contact_files) < 2:
+            raise ValueError("At least 2 contact files are required for combining")
+
+    def _load_contact_file(self, filename):
+        """Load a contact pickle file and return data and metadata."""
+        if not os.path.exists(filename):
+            raise FileNotFoundError(f"Contact file not found: {filename}")
+
+        with open(filename, 'rb') as f:
+            contacts = pickle.load(f)
+
+        metadata = contacts.dtype.metadata
+        return contacts, metadata
+
+    def _validate_compatibility(self, metadatas):
+        """Validate that contact files are compatible for combining."""
+        reference = metadatas[0]
+
+        # Check that all files have the same atom groups
+        for i, meta in enumerate(metadatas[1:], 1):
+            # Compare cutoff
+            if meta['cutoff'] != reference['cutoff']:
+                raise ValueError(f"Incompatible cutoffs: file 0 has {reference['cutoff']}, "
+                                 f"file {i} has {meta['cutoff']}")
+
+            # Compare atom group selections by checking if resids match
+            ref_ag1_resids = set(reference['ag1'].residues.resids)
+            ref_ag2_resids = set(reference['ag2'].residues.resids)
+            meta_ag1_resids = set(meta['ag1'].residues.resids)
+            meta_ag2_resids = set(meta['ag2'].residues.resids)
+
+            if ref_ag1_resids != meta_ag1_resids:
+                raise ValueError(f"Incompatible ag1 residues between file 0 and file {i}")
+            if ref_ag2_resids != meta_ag2_resids:
+                raise ValueError(f"Incompatible ag2 residues between file 0 and file {i}")
+
+        # Check timesteps and warn if different
+        timesteps = [meta['ts'] for meta in metadatas]
+        if not all(abs(ts - timesteps[0]) < 1e-6 for ts in timesteps):
+            print("WARNING: Different timesteps detected across runs:")
+            for i, (filename, ts) in enumerate(zip(self.contact_files, timesteps)):
+                print(f"  File {i} ({filename}): dt = {ts} ns")
+            print("This may affect residence time estimates, especially for fast events.")
+
+    def run(self):
+        """Combine contact files and save the result."""
+        print(f"Combining {len(self.contact_files)} contact files...")
+
+        all_contacts = []
+        all_metadatas = []
+
+        # Load all contact files
+        for i, filename in enumerate(self.contact_files):
+            print(f"Loading file {i+1}/{len(self.contact_files)}: {filename}")
+            contacts, metadata = self._load_contact_file(filename)
+            all_contacts.append(contacts)
+            all_metadatas.append(metadata)
+
+        # Validate compatibility if requested
+        if self.validate_compatibility:
+            print("Validating file compatibility...")
+            self._validate_compatibility(all_metadatas)
+
+        # Combine contact data
+        print("Combining contact data...")
+
+        # Calculate total size and create combined array
+        total_size = sum(len(contacts) for contacts in all_contacts)
+        reference_metadata = all_metadatas[0].copy()
+
+        # Extend metadata to include trajectory source information
+        reference_metadata['source_files'] = self.contact_files
+        reference_metadata['n_trajectories'] = len(self.contact_files)
+
+        # Determine number of columns (5 for raw contacts, 4 for processed)
+        n_cols = all_contacts[0].shape[1]
+
+        # Create dtype with extended metadata
+        combined_dtype = np.dtype(np.float64, metadata=reference_metadata)
+
+        # Add trajectory source column (will be last column)
+        combined_contacts = np.zeros((total_size, n_cols + 1), dtype=np.float64)
+
+        # Combine data and add trajectory source information
+        offset = 0
+        for traj_idx, contacts in enumerate(all_contacts):
+            n_contacts = len(contacts)
+            # Copy original contact data
+            combined_contacts[offset:offset+n_contacts, :n_cols] = contacts[:]
+            # Add trajectory source index
+            combined_contacts[offset:offset+n_contacts, n_cols] = traj_idx
+            offset += n_contacts
+
+        # Create final memmap with proper dtype
+        final_contacts = combined_contacts.view(combined_dtype)
+
+        # Save combined contacts
+        print(f"Saving combined contacts to {self.output_name}...")
+        final_contacts.dump(self.output_name, protocol=5)
+
+        print(f"Successfully combined {len(self.contact_files)} files into {self.output_name}")
+        print(f"Total contacts: {total_size}")
+        print(f"Added trajectory source column (index {n_cols}) for kinetic clustering support")
+
+        return self.output_name
+
+
 if __name__ == '__main__':
     """DOCSSS
     """

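The combination can also be driven from Python through the new class. A short sketch under the same assumptions as above (the two repeat contact pickles are placeholder names, not files shipped with this commit):

    from basicrta.contacts import CombineContacts

    # placeholder file names for contact pickles from two repeat runs
    combiner = CombineContacts(contact_files=['run1/contacts_7.0.pkl',
                                              'run2/contacts_7.0.pkl'],
                               output_name='combined_contacts.pkl')
    combined_file = combiner.run()   # validates, concatenates, and writes the pickle
    print(combined_file)             # 'combined_contacts.pkl'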
basicrta/gibbs.py

Lines changed: 14 additions & 1 deletion
@@ -52,6 +52,14 @@ def run(self, run_resids=None, g=100):
         with open(self.contacts, 'r+b') as f:
             contacts = pickle.load(f)
 
+        # Check if this is a combined contact file
+        metadata = contacts.dtype.metadata
+        is_combined = metadata and 'n_trajectories' in metadata and metadata['n_trajectories'] > 1
+        if is_combined:
+            print(f"WARNING: Using combined contact file with {metadata['n_trajectories']} trajectories.")
+            print("WARNING: Kinetic clustering is not yet supported for combined contacts.")
+            print("WARNING: The Gibbs sampler will pool all residence times together.")
+
         protids = np.unique(contacts[:, 0])
         if not run_resids:
             run_resids = protids

@@ -71,7 +79,7 @@ def run(self, run_resids=None, g=100):
                          run_resids])
         residues = residues[inds]
         input_list = [[residues[i], times[i].copy(), i % self.nproc,
-                       self.ncomp, self.niter, self.cutoff, g] for i in
+                       self.ncomp, self.niter, self.cutoff, g, is_combined] for i in
                       range(len(residues))]
 
         del contacts, times

@@ -227,6 +235,11 @@ def cluster(self, method="GaussianMixture", **kwargs):
         :param method: Mixture method to use
         :type method: str
         """
+        # Check if this Gibbs result was created from combined contact data
+        if hasattr(self, '_from_combined_contacts') and self._from_combined_contacts:
+            print("INFO: Using combined contact data for clustering. "
+                  "Trajectory source information is pooled together.")
+
         from sklearn import mixture
         from scipy import stats

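For reference, the metadata check performed at the top of `run()` above can be reproduced on its own to inspect a combined file; a minimal sketch, assuming a `combined_contacts.pkl` written by `CombineContacts` exists in the working directory:

    import pickle

    # load the combined contact array written by CombineContacts (assumed present)
    with open('combined_contacts.pkl', 'rb') as f:
        contacts = pickle.load(f)

    meta = contacts.dtype.metadata
    if meta and meta.get('n_trajectories', 1) > 1:
        print(f"Combined file pooling {meta['n_trajectories']} trajectories:")
        for src in meta['source_files']:
            print(f"  {src}")
        # the trajectory source index added by CombineContacts is the last column
        traj_ids = contacts[:, -1]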