Skip to content

Commit 38a5297

Browse files
authored
working cli (#46)
* initial working cli * taking parser from submodules * in-progress cli * working cli that takes parser from each script * forgot to add script execution, should be working now * removed subprocess call * removed unnecessary print statement * added help strings/subcommand description to cli * added argument requirements * added help * added cli addition to CHANGELOG * added documentation * added a few tests * added test * fixed test * referenced issue in CHANGELOG
1 parent 777f5f5 commit 38a5297

File tree

12 files changed

+463
-62
lines changed

12 files changed

+463
-62
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,17 @@ The rules for this file:
1919

2020
### Authors
2121
* @orbeckst
22+
* @rsexton2
2223

2324
### Fixed
2425
* Have cluster.ProcessProtein.reprocess() record "no result" if
2526
the gibbs.Gibbs.process_gibbs() step fails due to insufficient
2627
number of samples. Otherwise `python -m cluster` fails to process
2728
whole proteins.
2829

30+
### Added
31+
* Added command-line interface for basicrta workflow (Issue #20)
32+
2933
## [1.1.3] - 2025-09-11
3034

3135
### Authors

basicrta/cli.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
"""
2+
Command line functionality of basicrta.
3+
4+
The `main()` function of this module gets the argument parser from each of the
5+
scripts below and executes the `main()` function of the module called. The
6+
function also collects help from the subparsers and provides it at the command
7+
line.
8+
9+
Modules callable from the cli: contacts.py, gibbs.py, cluster.py, kinetics.py,
10+
combine.py.
11+
"""
12+
13+
from importlib.metadata import version
14+
import basicrta
15+
import argparse
16+
import subprocess
17+
import importlib
18+
import sys
19+
20+
__version__ = version("basicrta")
21+
22+
# define which scripts can be ran from cli
23+
# can easily add functionality to cli as modules are added
24+
commands = ['contacts', 'gibbs', 'cluster', 'combine', 'kinetics']
25+
26+
def main():
27+
""" This module provides the functionality for a command line interface for
28+
basicrta scripts. The scripts available to the cli are:
29+
30+
* contacts.py
31+
* gibbs.py
32+
* cluster.py
33+
* combine.py
34+
* kinetics.py
35+
36+
Each script is called and ran using the `main()` function of each module and
37+
the parser is passed to the cli using the `get_parser()` function. Any
38+
module added to the cli needs to have both functions.
39+
"""
40+
parser = argparse.ArgumentParser(prog='basicrta', add_help=True)
41+
subparsers = parser.add_subparsers(help="""step in the basicrta workflow to
42+
execute""")
43+
44+
# collect parser from each script in `commands`
45+
for command in commands:
46+
subparser = importlib.import_module(f"basicrta.{command}").get_parser()
47+
subparsers.add_parser(f'{command}', parents=[subparser], add_help=True,
48+
description=subparser.description,
49+
conflict_handler='resolve',
50+
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
51+
help=subparser.description)
52+
53+
# print subparser help if no arguments given
54+
if len(sys.argv) == 2 and sys.argv[1] in commands:
55+
subparsers.choices[f'{sys.argv[1]}'].print_help()
56+
sys.exit()
57+
58+
# print basicrta help if no subcommand given
59+
parser.parse_args(args=None if sys.argv[1:] else ['--help'])
60+
61+
# execute basicrta script
62+
importlib.import_module(f"basicrta.{sys.argv[1]}").main()
63+
64+
if __name__ == "__main__":
65+
main()

basicrta/cluster.py

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
"""This module provides the ProcessProtein class, which collects and processes
2+
Gibbs sampler data.
3+
"""
4+
15
import os
26
import gc
37
import warnings
@@ -11,9 +15,6 @@
1115
from basicrta.gibbs import Gibbs
1216
gc.enable()
1317

14-
"""This module provides the ProcessProtein class, which collects and processes
15-
Gibbs sampler data.
16-
"""
1718

1819
class ProcessProtein(object):
1920
r"""ProcessProtein is the class that collects and processes Gibbs sampler
@@ -237,33 +238,54 @@ def b_color_structure(self, structure):
237238

238239
u.select_atoms('protein').write('tau_bcolored.pdb')
239240

240-
241-
if __name__ == "__main__": #pragma: no cover
242-
# the script is tested in the test_cluster.py but cannot be accounted for
243-
# in the coverage report
241+
def get_parser():
244242
import argparse
245-
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
246-
parser.add_argument('--nproc', type=int, default=1)
247-
parser.add_argument('--cutoff', type=float)
248-
parser.add_argument('--niter', type=int, default=110000)
249-
parser.add_argument('--prot', type=str, default=None, nargs='?')
250-
parser.add_argument('--label-cutoff', type=float, default=3,
251-
dest='label_cutoff',
252-
help='Only label residues with tau > '
253-
'LABEL-CUTOFF * <tau>. ')
254-
parser.add_argument('--structure', type=str, nargs='?')
243+
parser = argparse.ArgumentParser(description="""perform clustering for each
244+
residue located in basicrta-{cutoff}/""",
245+
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
246+
required = parser.add_argument_group('required arguments')
247+
248+
required.add_argument('--cutoff', required=True, type=float, help="""cutoff
249+
used in contact analysis, will cluster results in
250+
basicrta-{cutoff}/""")
251+
parser.add_argument('--nproc', type=int, default=1, help="""number of
252+
processes to use in multiprocessing""")
253+
parser.add_argument('--niter', type=int, default=110000, help="""number of
254+
iterations used in the gibbs sampler, used to load
255+
gibbs_{niter}.pkl""")
256+
parser.add_argument('--prot', type=str, nargs='?', help="""name of protein
257+
in tm_dict.txt, used to draw TM bars in tau vs resid
258+
plot""")
259+
parser.add_argument('--label_cutoff', type=float, default=3,
260+
dest='label_cutoff',
261+
help="""Only label residues with tau >
262+
LABEL-CUTOFF * <tau>.""")
263+
parser.add_argument('--structure', type=str, nargs='?', help="""will add tau
264+
as bfactors to the structure if provided""")
255265
# use for default values
256266
parser.add_argument('--gskip', type=int, default=100,
257267
help='Gibbs skip parameter for decorrelated samples;'
258268
'default from https://pubs.acs.org/doi/10.1021/acs.jctc.4c01522')
259269
parser.add_argument('--burnin', type=int, default=10000,
260270
help='Burn-in parameter, drop first N samples as equilibration;'
261271
'default from https://pubs.acs.org/doi/10.1021/acs.jctc.4c01522')
272+
# this is to make the cli work, should be just a temporary solution
273+
parser.add_argument('cluster', nargs='?', help=argparse.SUPPRESS)
274+
return parser
262275

276+
def main():
277+
parser = get_parser()
263278
args = parser.parse_args()
264279

265280
pp = ProcessProtein(args.niter, args.prot, args.cutoff,
266281
gskip=args.gskip, burnin=args.burnin)
267282
pp.reprocess(nproc=args.nproc)
268283
pp.write_data()
269284
pp.plot_protein(label_cutoff=args.label_cutoff)
285+
286+
287+
if __name__ == "__main__": #pragma: no cover
288+
# the script is tested in the test_cluster.py but cannot be accounted for
289+
# in the coverage report
290+
exit(main())
291+

basicrta/combine.py

Lines changed: 148 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,162 @@
11
#!/usr/bin/env python
22

33
"""
4-
Command-line interface for combining contact timeseries from multiple repeat runs.
4+
Combine contact timeseries from multiple repeat runs.
55
66
This module provides functionality to combine contact files from multiple
77
trajectory repeats, enabling pooled analysis of binding kinetics.
88
"""
99

1010
import os
1111
import argparse
12-
from basicrta.contacts import CombineContacts
1312

13+
class CombineContacts(object):
14+
"""Class to combine contact timeseries from multiple repeat runs.
15+
16+
This class enables pooling data from multiple trajectory repeats and
17+
calculating posteriors from all data together, rather than analyzing
18+
each run separately.
19+
20+
:param contact_files: List of contact pickle files to combine
21+
:type contact_files: list of str
22+
:param output_name: Name for the combined output file (default: 'combined_contacts.pkl')
23+
:type output_name: str, optional
24+
:param validate_compatibility: Whether to validate that files are compatible (default: True)
25+
:type validate_compatibility: bool, optional
26+
"""
27+
28+
def __init__(self, contact_files, output_name='combined_contacts.pkl',
29+
validate_compatibility=True):
30+
self.contact_files = contact_files
31+
self.output_name = output_name
32+
self.validate_compatibility = validate_compatibility
33+
34+
if len(contact_files) < 2:
35+
raise ValueError("At least 2 contact files are required for combining")
36+
37+
def _load_contact_file(self, filename):
38+
"""Load a contact pickle file and return data and metadata."""
39+
if not os.path.exists(filename):
40+
raise FileNotFoundError(f"Contact file not found: {filename}")
41+
42+
with open(filename, 'rb') as f:
43+
contacts = pickle.load(f)
44+
45+
metadata = contacts.dtype.metadata
46+
return contacts, metadata
47+
48+
def _validate_compatibility(self, metadatas):
49+
"""Validate that contact files are compatible for combining."""
50+
reference = metadatas[0]
51+
52+
# Check that all files have the same atom groups
53+
for i, meta in enumerate(metadatas[1:], 1):
54+
# Compare cutoff
55+
if meta['cutoff'] != reference['cutoff']:
56+
raise ValueError(f"Incompatible cutoffs: file 0 has {reference['cutoff']}, "
57+
f"file {i} has {meta['cutoff']}")
58+
59+
# Compare atom group selections by checking if resids match
60+
ref_ag1_resids = set(reference['ag1'].residues.resids)
61+
ref_ag2_resids = set(reference['ag2'].residues.resids)
62+
meta_ag1_resids = set(meta['ag1'].residues.resids)
63+
meta_ag2_resids = set(meta['ag2'].residues.resids)
64+
65+
if ref_ag1_resids != meta_ag1_resids:
66+
raise ValueError(f"Incompatible ag1 residues between file 0 and file {i}")
67+
if ref_ag2_resids != meta_ag2_resids:
68+
raise ValueError(f"Incompatible ag2 residues between file 0 and file {i}")
69+
70+
# Check timesteps and warn if different
71+
timesteps = [meta['ts'] for meta in metadatas]
72+
if not all(abs(ts - timesteps[0]) < 1e-6 for ts in timesteps):
73+
print("WARNING: Different timesteps detected across runs:")
74+
for i, (filename, ts) in enumerate(zip(self.contact_files, timesteps)):
75+
print(f" File {i} ({filename}): dt = {ts} ns")
76+
print("This may affect residence time estimates, especially for fast events.")
77+
78+
def run(self):
79+
"""Combine contact files and save the result."""
80+
print(f"Combining {len(self.contact_files)} contact files...")
81+
82+
all_contacts = []
83+
all_metadatas = []
84+
85+
# Load all contact files
86+
for i, filename in enumerate(self.contact_files):
87+
print(f"Loading file {i+1}/{len(self.contact_files)}: {filename}")
88+
contacts, metadata = self._load_contact_file(filename)
89+
all_contacts.append(contacts)
90+
all_metadatas.append(metadata)
91+
92+
# Validate compatibility if requested
93+
if self.validate_compatibility:
94+
print("Validating file compatibility...")
95+
self._validate_compatibility(all_metadatas)
96+
97+
# Combine contact data
98+
print("Combining contact data...")
99+
100+
# Calculate total size and create combined array
101+
total_size = sum(len(contacts) for contacts in all_contacts)
102+
reference_metadata = all_metadatas[0].copy()
103+
104+
# Extend metadata to include trajectory source information
105+
reference_metadata['source_files'] = self.contact_files
106+
reference_metadata['n_trajectories'] = len(self.contact_files)
107+
108+
# Determine number of columns (5 for raw contacts, 4 for processed)
109+
n_cols = all_contacts[0].shape[1]
110+
111+
# Create dtype with extended metadata
112+
combined_dtype = np.dtype(np.float64, metadata=reference_metadata)
113+
114+
# Add trajectory source column (will be last column)
115+
combined_contacts = np.zeros((total_size, n_cols + 1), dtype=np.float64)
116+
117+
# Combine data and add trajectory source information
118+
offset = 0
119+
for traj_idx, contacts in enumerate(all_contacts):
120+
n_contacts = len(contacts)
121+
# Copy original contact data
122+
combined_contacts[offset:offset+n_contacts, :n_cols] = contacts[:]
123+
# Add trajectory source index
124+
combined_contacts[offset:offset+n_contacts, n_cols] = traj_idx
125+
offset += n_contacts
126+
127+
# Create final memmap with proper dtype
128+
final_contacts = combined_contacts.view(combined_dtype)
129+
130+
# Save combined contacts
131+
print(f"Saving combined contacts to {self.output_name}...")
132+
final_contacts.dump(self.output_name, protocol=5)
133+
134+
print(f"Successfully combined {len(self.contact_files)} files into {self.output_name}")
135+
print(f"Total contacts: {total_size}")
136+
print(f"Added trajectory source column (index {n_cols}) for kinetic clustering support")
137+
138+
return self.output_name
14139

15-
def main():
16-
"""Main function for combining contact files."""
140+
def get_parser():
141+
"""Create parser, parse command line arguments, and return ArgumentParser
142+
object.
143+
144+
:return: An ArgumentParser instance with command line arguments stored.
145+
:rtype: `ArgumentParser` object
146+
"""
17147
parser = argparse.ArgumentParser(
18148
description="Combine contact timeseries from multiple repeat runs. "
19149
"This enables pooling data from multiple trajectory repeats "
20150
"and calculating posteriors from all data together."
21151
)
22152

23-
parser.add_argument(
153+
required = parser.add_argument_group('required arguments')
154+
required.add_argument(
24155
'--contacts',
25156
nargs='+',
26157
required=True,
27-
help="List of contact pickle files to combine (e.g., contacts_7.0.pkl from different runs)"
158+
help="""List of contact pickle files to combine (e.g., contacts_7.0.pkl
159+
from different runs)""",
28160
)
29161

30162
parser.add_argument(
@@ -39,7 +171,14 @@ def main():
39171
action='store_true',
40172
help="Skip compatibility validation (use with caution)"
41173
)
42-
174+
# this is to make the cli work, should be just a temporary solution
175+
parser.add_argument('combine', nargs='?', help=argparse.SUPPRESS)
176+
return parser
177+
178+
def main():
179+
"""Execute this function when this script is called from the command line.
180+
"""
181+
parser = get_parser()
43182
args = parser.parse_args()
44183

45184
# Validate input files exist
@@ -82,6 +221,5 @@ def main():
82221
print(f"ERROR: {e}")
83222
return 1
84223

85-
86-
if __name__ == '__main__':
87-
exit(main())
224+
if __name__ == "__main__":
225+
exit(main())

0 commit comments

Comments
 (0)