Commit 89edb87

Merge pull request #267 from Jhsmit/cli_process

Jobfiles and process from the command line

2 parents: d949888 + de28097

13 files changed: +294 -60 lines changed

docs/installation.rst
Lines changed: 1 addition & 1 deletion

@@ -142,7 +142,7 @@ Generate conda requirements files `from setup.cfg`:
     $ python _requirements.py


-First, if you would like a specific PyTorch version to use with PyHDX (ie CUDA/ROCm support), you should install this first.
+If you would like a specific PyTorch version to use with PyHDX (ie CUDA/ROCm support), you should install this first.
 Installation instructions are on the Pytorch_ website.

 Then, install the other base dependencies and optional extras. For example, to install PyHDX with web app:

pyhdx/batch_processing.py
Lines changed: 178 additions & 12 deletions

@@ -1,20 +1,31 @@
 import warnings
+from functools import reduce
 from pathlib import Path
 import os
-from pyhdx.models import PeptideMasterTable, HDXMeasurement, HDXMeasurementSet
-from pyhdx.fileIO import read_dynamx
+import re

+from pyhdx import TorchFitResult
+from pyhdx.models import PeptideMasterTable, HDXMeasurement, HDXMeasurementSet
+from pyhdx.fileIO import read_dynamx, csv_to_dataframe, save_fitresult
+from pyhdx.fitting import fit_rates_half_time_interpolate, fit_rates_weighted_average, \
+    fit_gibbs_global, fit_gibbs_global_batch, RatesFitResult, GenericFitResult
+import param
+import pandas as pd
+from pyhdx.support import gen_subclasses
+import yaml

 time_factors = {"s": 1, "m": 60.0, "min": 60.0, "h": 3600, "d": 86400}
 temperature_offsets = {"c": 273.15, "celsius": 273.15, "k": 0, "kelvin": 0}

-# todo add data filters in yaml spec
+
+# todo add data filters in state spec?
 # todo add proline, n_term options
-class YamlParser(object):
-    """object used to parse yaml data input files into PyHDX HDXMeasurement objects"""
+class StateParser(object):
+    """object used to parse yaml state input files into PyHDX HDXMeasurement objects"""

-    def __init__(self, yaml_dict, data_src=None, data_filters=None):
-        self.yaml_dict = yaml_dict
+    # todo yaml_dict -> state_spec
+    def __init__(self, state_spec, data_src=None, data_filters=None):
+        self.state_spec = state_spec
         if isinstance(data_src, (os.PathLike, str)):
             self.data_src = Path(data_src)
         elif isinstance(data_src, dict):

@@ -44,7 +55,7 @@ def load_data(self, *filenames, reader='dynamx'):
     def load_hdxmset(self):
         """batch read the full yaml spec into a hdxmeasurementset"""
         hdxm_list = []
-        for state in self.yaml_dict.keys():
+        for state in self.state_spec.keys():
             hdxm = self.load_hdxm(state, name=state)
             hdxm_list.append(hdxm)

@@ -55,7 +66,7 @@ def load_hdxm(self, state, **kwargs):
         kwargs: additional kwargs passed to hdxmeasurementset
         """

-        state_dict = self.yaml_dict[state]
+        state_dict = self.state_spec[state]

         filenames = state_dict["filenames"]
         df = self.load_data(*filenames)

@@ -95,8 +106,8 @@ def load_hdxm(self, state, **kwargs):
             raise ValueError("Must specify either 'c_term' or 'sequence'")

         state_data = pmt.get_state(state_dict["state"])
-        for filter in self.data_filters:
-            state_data = filter(state_data)
+        for flt in self.data_filters:
+            state_data = flt(state_data)

         hdxm = HDXMeasurement(
             state_data,

@@ -111,16 +122,169 @@ def load_hdxm(self, state, **kwargs):
         return hdxm


+process_functions = {
+    'csv_to_dataframe': csv_to_dataframe,
+    'fit_rates_half_time_interpolate': fit_rates_half_time_interpolate,
+    'fit_rates_weighted_average': fit_rates_weighted_average,
+    'fit_gibbs_global': fit_gibbs_global
+}
+
+
+# task objects should be param
+class Task(param.Parameterized):
+    ...
+
+    scheduler_address = param.String(doc='Optional scheduler address for dask task')
+
+    cwd = param.ClassSelector(Path, doc='Path of the current working directory')
+
+
+class LoadHDMeasurementSetTask(Task):
+    _type = 'load_hdxm_set'
+
+    state_file = param.String()  # = string path
+
+    out = param.ClassSelector(HDXMeasurementSet)
+
+    def execute(self, *args, **kwargs):
+        state_spec = yaml.safe_load((self.cwd / self.state_file).read_text())
+        parser = StateParser(state_spec, self.cwd, default_filters)
+        hdxm_set = parser.load_hdxmset()
+
+        self.out = hdxm_set
+
+
+class EstimateRates(Task):
+    _type = 'estimate_rates'
+
+    hdxm_set = param.ClassSelector(HDXMeasurementSet)
+
+    select_state = param.String(doc='If set, only use this state for creating initial guesses')
+
+    out = param.ClassSelector((RatesFitResult, GenericFitResult))
+
+    def execute(self, *args, **kwargs):
+        if self.select_state:  # refactor to 'state' ?
+            hdxm = self.hdxm_set.get(self.select_state)
+            result = fit_rates_half_time_interpolate(hdxm)
+        else:
+            results = []
+            for hdxm in self.hdxm_set:
+                r = fit_rates_half_time_interpolate(hdxm)
+                results.append(r)
+            result = RatesFitResult(results)
+
+        self.out = result
+
+
+# todo allow guesses from deltaG
+class ProcessGuesses(Task):
+    _type = 'create_guess'
+
+    hdxm_set = param.ClassSelector(HDXMeasurementSet)
+
+    select_state = param.String(doc='If set, only use this state for creating initial guesses')
+
+    rates_df = param.ClassSelector(pd.DataFrame)
+
+    out = param.ClassSelector((pd.Series, pd.DataFrame))
+
+    def execute(self, *args, **kwargs):
+        if self.select_state:
+            hdxm = self.hdxm_set.get(self.select_state)
+            if self.rates_df.columns.nlevels == 2:
+                rates_series = self.rates_df[(self.select_state, 'rate')]
+            else:
+                rates_series = self.rates_df['rate']
+
+            guess = hdxm.guess_deltaG(rates_series)
+
+        else:
+            rates = self.rates_df.xs('rate', level=-1, axis=1)
+            guess = self.hdxm_set.guess_deltaG(rates)
+
+        self.out = guess
+
+
+class FitGlobalBatch(Task):
+    _type = 'fit_global_batch'
+
+    hdxm_set = param.ClassSelector(HDXMeasurementSet)
+
+    initial_guess = param.ClassSelector(
+        (pd.Series, pd.DataFrame), doc='Initial guesses for fits')
+
+    out = param.ClassSelector(TorchFitResult)
+
+    def execute(self, *args, **kwargs):
+        result = fit_gibbs_global_batch(self.hdxm_set, self.initial_guess, **kwargs)
+
+        self.out = result
+
+
+class SaveFitResult(Task):
+    _type = 'save_fit_result'
+
+    fit_result = param.ClassSelector(TorchFitResult)
+
+    output_dir = param.String()
+
+    def execute(self, *args, **kwargs):
+        save_fitresult(self.cwd / self.output_dir, self.fit_result)
+
+
+class JobParser(object):
+
+    cwd = param.ClassSelector(Path, doc='Path of the current working directory')
+
+    def __init__(self, job_spec, cwd=None):
+        self.job_spec = job_spec
+        self.cwd = cwd or Path().cwd()
+
+        self.tasks = {}
+        self.task_classes = {cls._type: cls for cls in gen_subclasses(Task) if getattr(cls, "_type", None)}
+
+    def resolve_var(self, var_string):
+        task_name, *attrs = var_string.split('.')
+
+        return reduce(getattr, attrs, self.tasks[task_name])
+
+    def execute(self):
+
+        for task_spec in self.job_spec['steps']:
+            task_klass = self.task_classes[task_spec['task']]
+            skip = {'args', 'kwargs', 'task'}
+
+            resolved_params = {}
+            for par_name in task_spec.keys() - skip:
+                value = task_spec[par_name]
+                if isinstance(value, str):
+                    m = re.findall(r'\$\((.*?)\)', value)
+                    if m:
+                        value = self.resolve_var(m[0])
+                resolved_params[par_name] = value
+            task = task_klass(cwd=self.cwd, **resolved_params)
+            task.execute(*task_spec.get('args', []), **task_spec.get('kwargs', {}))
+
+            self.tasks[task.name] = task
+
+
 def yaml_to_hdxmset(yaml_dict, data_dir=None, **kwargs):
     """reads files according to `yaml_dict` spec from `data_dir` into HDXMeasurementSet"""

+    warnings.warn("yaml_to_hdxmset is deprecated, use 'StateParser'")
     hdxm_list = []
     for k, v in yaml_dict.items():
         hdxm = yaml_to_hdxm(v, data_dir=data_dir, name=k)
         hdxm_list.append(hdxm)

     return HDXMeasurementSet(hdxm_list)

+# todo configurable
+default_filters = [
+    lambda df: df.query('exposure > 0')
+]
+

 def yaml_to_hdxm(yaml_dict, data_dir=None, data_filters=None, **kwargs):
     # todo perhaps classmethod on HDXMeasurement object?

@@ -142,7 +306,7 @@ def yaml_to_hdxm(yaml_dict, data_dir=None, data_filters=None, **kwargs):
     Output data object as specified by `yaml_dict`.
     """

-    warnings.warn('This method is deprecated in favor of YamlParser', DeprecationWarning)
+    warnings.warn('This method is deprecated in favor of StateParser', DeprecationWarning)

     if data_dir is not None:
         input_files = [Path(data_dir) / fname for fname in yaml_dict["filenames"]]

@@ -270,3 +434,5 @@ def load_from_yaml_v040b2(yaml_dict, data_dir=None, **kwargs):  # pragma: no cover
     )

     return hdxm
+
+
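The new JobParser turns a declarative job spec into a pipeline: each entry in job_spec['steps'] selects a Task subclass via its _type string, the remaining keys become param values, and string values of the form $(task_name.attribute) are resolved against earlier tasks through resolve_var. The sketch below shows such a spec as the Python dict that yaml.safe_load would produce. The step names, the epochs value, the rates.out.output attribute path, and the convention that each step carries a name key (so later steps can reference it) are illustrative assumptions, not taken from a shipped example jobfile:

    from pathlib import Path

    from pyhdx.batch_processing import JobParser

    # Dict form of a jobfile (what yaml.safe_load would return). Each step
    # names a Task subclass via its _type string; a 'name' key is assumed so
    # later steps can reference earlier results with $(name.attribute).
    job_spec = {
        "steps": [
            {
                "task": "load_hdxm_set",         # LoadHDMeasurementSetTask
                "name": "load_data",
                "state_file": "states.yaml",     # read relative to cwd
            },
            {
                "task": "estimate_rates",        # EstimateRates
                "name": "rates",
                "hdxm_set": "$(load_data.out)",  # resolved by resolve_var
            },
            {
                "task": "create_guess",          # ProcessGuesses
                "name": "guesses",
                "hdxm_set": "$(load_data.out)",
                # assumes RatesFitResult exposes its rates table as 'output'
                "rates_df": "$(rates.out.output)",
            },
            {
                "task": "fit_global_batch",      # FitGlobalBatch
                "name": "fit",
                "hdxm_set": "$(load_data.out)",
                "initial_guess": "$(guesses.out)",
                "kwargs": {"epochs": 1000},      # forwarded to fit_gibbs_global_batch
            },
            {
                "task": "save_fit_result",       # SaveFitResult
                "name": "save",
                "fit_result": "$(fit.out)",
                "output_dir": "fit_result",      # created under cwd
            },
        ]
    }

    parser = JobParser(job_spec, cwd=Path("my_project"))
    parser.execute()

In practice the spec lives in a .yaml jobfile and is executed with the new pyhdx process command added in pyhdx/cli.py below.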

pyhdx/cli.py
Lines changed: 34 additions & 24 deletions

@@ -1,32 +1,30 @@
-import argparse
 import time
-from ipaddress import ip_address
-from pyhdx.web import serve
-from pyhdx.config import cfg
-from pyhdx.local_cluster import verify_cluster, default_cluster
+from typing import Union, Optional
+from pathlib import Path

+import typer
+from ipaddress import ip_address
+import yaml

-# todo add check to see if the web module requirements are installed

+app = typer.Typer()

-def main():
-    parser = argparse.ArgumentParser(prog="pyhdx", description="PyHDX Launcher")
+@app.command()
+def serve(scheduler_address: Optional[str] = typer.Option(None, help="Address for dask scheduler to use")):
+    """Launch the PyHDX web application"""

-    parser.add_argument("serve", help="Runs PyHDX Dashboard")
-    parser.add_argument(
-        "--scheduler_address", help="Run with local cluster <ip>:<port>"
-    )
-    args = parser.parse_args()
+    from pyhdx.config import cfg
+    from pyhdx.local_cluster import verify_cluster, default_cluster

-    if args.scheduler_address:
-        ip, port = args.scheduler_address.split(":")
+    if scheduler_address is not None:
+        ip, port = scheduler_address.split(":")
         if not ip_address(ip):
             print("Invalid IP Address")
             return
         elif not 0 <= int(port) < 2 ** 16:
             print("Invalid port, must be 0-65535")
             return
-        cfg.set("cluster", "scheduler_address", args.scheduler_address)
+        cfg.set("cluster", "scheduler_address", scheduler_address)

     scheduler_address = cfg.get("cluster", "scheduler_address")
     if not verify_cluster(scheduler_address):

@@ -37,8 +35,9 @@ def main():
         scheduler_address = f"{ip}:{port}"
         print(f"Started new Dask LocalCluster at {scheduler_address}")

-    if args.serve:
-        serve.run_apps()
+    # Start the PyHDX web application
+    from pyhdx.web import serve as serve_pyhdx
+    serve_pyhdx.run_apps()

     loop = True
     while loop:

@@ -49,11 +48,22 @@ def main():
             loop = False


-if __name__ == "__main__":
-    import sys
+@app.command()
+def process(
+        jobfile: Path = typer.Argument(..., help="Path to .yaml jobfile"),
+        cwd: Optional[Path] = typer.Option(None, help="Optional path to working directory")
+):
+    """
+    Process an HDX dataset according to a jobfile
+    """
+
+    from pyhdx.batch_processing import JobParser

-    sys.argv.append("serve")
-    sys.argv.append("--scheduler_address")
-    sys.argv.append("127.0.0.1:53270")
+    job_spec = yaml.safe_load(jobfile.read_text())
+    parser = JobParser(job_spec, cwd=cwd)

-    main()
+    parser.execute()
+
+
+if __name__ == "__main__":
+    app()
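With both commands registered on one typer app, the CLI now exposes subcommands: pyhdx serve and pyhdx process (option names assume typer's default underscore-to-dash conversion, e.g. --scheduler-address). A small sketch using typer's test runner to exercise the process command in-process; job.yaml and my_project are placeholder paths:

    from typer.testing import CliRunner

    from pyhdx.cli import app

    runner = CliRunner()

    # Equivalent to 'pyhdx process job.yaml --cwd my_project' on the shell,
    # assuming the installed console script points at this typer app.
    result = runner.invoke(app, ["process", "job.yaml", "--cwd", "my_project"])

    print(result.exit_code)  # 0 on success
    print(result.output)     # anything the tasks printed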

pyhdx/fileIO.py
Lines changed: 1 addition & 1 deletion

@@ -371,7 +371,7 @@ def save_fitresult(output_dir, fit_result, log_lines=None):
     dataframe_to_file(output_dir / "losses.csv", fit_result.losses)
     dataframe_to_file(output_dir / "losses.txt", fit_result.losses, fmt="pprint")

-    if isinstance(fit_result.hdxm_set, pyhdx.HDXMeasurement):
+    if isinstance(fit_result.hdxm_set, pyhdx.HDXMeasurement):  # check, but this should always be hdxm_set
         fit_result.hdxm_set.to_file(output_dir / "HDXMeasurement.csv")
     if isinstance(fit_result.hdxm_set, pyhdx.HDXMeasurementSet):
         fit_result.hdxm_set.to_file(output_dir / "HDXMeasurements.csv")

pyhdx/fitting.py
Lines changed: 1 addition & 0 deletions

@@ -929,6 +929,7 @@ class GenericFitResult:

 @dataclass
 class RatesFitResult:
+    """Accumulates multiple Generic/KineticsFit Results"""
     results: list

     @property

pyhdx/models.py
Lines changed: 6 additions & 0 deletions

@@ -1219,6 +1219,12 @@ def __iter__(self):
     def __getitem__(self, item):
         return self.hdxm_list.__getitem__(item)

+    def get(self, name):
+        """find a HDXMeasurement by name"""
+
+        idx = self.names.index(name)
+        return self[idx]
+
     @property
     def Ns(self):
         return len(self.hdxm_list)
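The new HDXMeasurementSet.get is what EstimateRates and ProcessGuesses rely on when select_state is set. A minimal sketch of the lookup; the measurement variables and the state name are hypothetical:

    from pyhdx.models import HDXMeasurementSet

    # hdxm_apo and hdxm_bound stand in for HDXMeasurement objects loaded
    # elsewhere (e.g. via StateParser); their names are set at construction.
    hdxm_set = HDXMeasurementSet([hdxm_apo, hdxm_bound])

    hdxm = hdxm_set.get("SecB_apo")  # names.index(...) then positional lookup
    # an unknown name raises ValueError from list.index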

pyhdx/web/controllers.py
Lines changed: 2 additions & 2 deletions

@@ -20,7 +20,7 @@
 from proplot import to_hex
 from skimage.filters import threshold_multiotsu

-from pyhdx.batch_processing import YamlParser
+from pyhdx.batch_processing import StateParser
 from pyhdx.config import cfg
 from pyhdx.fileIO import read_dynamx, csv_to_dataframe, dataframe_to_stringio
 from pyhdx.fitting import (

@@ -499,7 +499,7 @@ def _add_dataset_batch(self):
         ios = {name: StringIO(byte_content.decode("UTF-8")) for name, byte_content in zip(self.widgets['input_files'].filename, self.input_files)}
         filters = [lambda df: df.query('exposure > 0')]

-        parser = YamlParser(yaml_dict, data_src=ios, data_filters=filters)
+        parser = StateParser(yaml_dict, data_src=ios, data_filters=filters)

         for state in yaml_dict.keys():
             hdxm = parser.load_hdxm(state, name=state)
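The web controller feeds StateParser in-memory StringIO buffers, but the same parser reads from disk when data_src is a path, which is how LoadHDMeasurementSetTask uses it. A sketch, assuming a states.yaml spec and a data directory of raw files sit next to the script:

    from pathlib import Path

    import yaml

    from pyhdx.batch_processing import StateParser, default_filters

    # states.yaml holds per-state entries with 'filenames', 'state' and
    # 'c_term'/'sequence' keys, as consumed by load_hdxm above.
    state_spec = yaml.safe_load(Path("states.yaml").read_text())

    parser = StateParser(state_spec, data_src=Path("data"), data_filters=default_filters)
    hdxm_set = parser.load_hdxmset()  # one HDXMeasurement per top-level state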
