Skip to content

Commit 0cce9bb

Browse files
committed
add JobParser and template example usage
1 parent eb923b0 commit 0cce9bb

File tree

11 files changed

+245
-22
lines changed

11 files changed

+245
-22
lines changed

pyhdx/batch_processing.py

Lines changed: 179 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,31 @@
11
import warnings
2+
from functools import reduce
23
from pathlib import Path
34
import os
4-
from pyhdx.models import PeptideMasterTable, HDXMeasurement, HDXMeasurementSet
5-
from pyhdx.fileIO import read_dynamx
5+
import re
66

7+
from pyhdx import TorchFitResult
8+
from pyhdx.models import PeptideMasterTable, HDXMeasurement, HDXMeasurementSet
9+
from pyhdx.fileIO import read_dynamx, csv_to_dataframe, save_fitresult
10+
from pyhdx.fitting import fit_rates_half_time_interpolate, fit_rates_weighted_average, \
11+
fit_gibbs_global, fit_gibbs_global_batch, RatesFitResult, GenericFitResult
12+
import param
13+
import pandas as pd
14+
from pyhdx.support import gen_subclasses
15+
import yaml
716

817
time_factors = {"s": 1, "m": 60.0, "min": 60.0, "h": 3600, "d": 86400}
918
temperature_offsets = {"c": 273.15, "celsius": 273.15, "k": 0, "kelvin": 0}
1019

11-
# todo add data filters in yaml spec
20+
21+
# todo add data filters in state spec?
1222
# todo add proline, n_term options
13-
class YamlParser(object):
14-
""'object used to parse yaml data input files into PyHDX HDX Measurement object'
23+
class StateParser(object):
24+
""'object used to parse yaml state input files into PyHDX HDX Measurement object'
1525

16-
def __init__(self, yaml_dict, data_src=None, data_filters=None):
17-
self.yaml_dict = yaml_dict
26+
# todo yaml_dict -> state_spec
27+
def __init__(self, state_spec, data_src=None, data_filters=None):
28+
self.state_spec = state_spec
1829
if isinstance(data_src, (os.PathLike, str)):
1930
self.data_src = Path(data_src)
2031
elif isinstance(data_src, dict):
@@ -44,7 +55,7 @@ def load_data(self, *filenames, reader='dynamx'):
4455
def load_hdxmset(self):
4556
"""batch read the full yaml spec into a hdxmeasurementset"""
4657
hdxm_list = []
47-
for state in self.yaml_dict.keys():
58+
for state in self.state_spec.keys():
4859
hdxm = self.load_hdxm(state, name=state)
4960
hdxm_list.append(hdxm)
5061

@@ -55,7 +66,7 @@ def load_hdxm(self, state, **kwargs):
5566
kwargs: additional kwargs passed to hdxmeasurementset
5667
"""
5768

58-
state_dict = self.yaml_dict[state]
69+
state_dict = self.state_spec[state]
5970

6071
filenames = state_dict["filenames"]
6172
df = self.load_data(*filenames)
@@ -95,8 +106,8 @@ def load_hdxm(self, state, **kwargs):
95106
raise ValueError("Must specify either 'c_term' or 'sequence'")
96107

97108
state_data = pmt.get_state(state_dict["state"])
98-
for filter in self.data_filters:
99-
state_data = filter(state_data)
109+
for flt in self.data_filters:
110+
state_data = flt(state_data)
100111

101112
hdxm = HDXMeasurement(
102113
state_data,
@@ -111,16 +122,170 @@ def load_hdxm(self, state, **kwargs):
111122
return hdxm
112123

113124

125+
# Name -> callable lookup table of processing functions
# (presumably for resolving function names given as strings in yaml specs —
# not referenced elsewhere in this diff; TODO confirm intended use)
process_functions = {
    'csv_to_dataframe': csv_to_dataframe,
    'fit_rates_half_time_interpolate': fit_rates_half_time_interpolate,
    'fit_rates_weighted_average': fit_rates_weighted_average,
    'fit_gibbs_global': fit_gibbs_global

}
132+
133+
# task objects should be param
class Task(param.Parameterized):
    """Base class for job-pipeline tasks.

    Subclasses define a class-level ``_type`` string (used by ``JobParser`` to
    look the task up from a job spec) and implement ``execute``.
    """

    scheduler_address = param.String(doc='Optional scheduler address for dask task')

    cwd = param.ClassSelector(Path, doc='Path of the current working directory')
140+
141+
142+
class LoadHDMeasurementSetTask(Task):
    """Load a yaml state specification file into an ``HDXMeasurementSet``."""

    _type = 'load_hdxm_set'

    state_file = param.String()  # = string path, resolved relative to `cwd`

    # result of the task, set by `execute`
    out = param.ClassSelector(HDXMeasurementSet)

    def execute(self, *args, **kwargs):
        # Read and parse the state spec relative to the working directory
        # (debug print of self.cwd removed)
        state_spec = yaml.safe_load((self.cwd / self.state_file).read_text())
        parser = StateParser(state_spec, self.cwd, default_filters)
        hdxm_set = parser.load_hdxmset()

        self.out = hdxm_set
156+
157+
158+
class EstimateRates(Task):
    """Estimate exchange rates by half-time interpolation, for one state or all."""

    _type = 'estimate_rates'

    hdxm_set = param.ClassSelector(HDXMeasurementSet)

    select_state = param.String(doc='If set, only use this state for creating initial guesses')

    # single-state runs yield a GenericFitResult, batch runs a RatesFitResult
    out = param.ClassSelector((RatesFitResult, GenericFitResult))

    def execute(self, *args, **kwargs):
        if self.select_state:  # refactor to 'state' ?
            selected = self.hdxm_set.get(self.select_state)
            fit_result = fit_rates_half_time_interpolate(selected)
        else:
            # Fit every measurement in the set and accumulate the results
            per_measurement = [
                fit_rates_half_time_interpolate(measurement)
                for measurement in self.hdxm_set
            ]
            fit_result = RatesFitResult(per_measurement)

        self.out = fit_result
179+
180+
181+
# todo allow guesses from deltaG
182+
class ProcessGuesses(Task):
183+
_type = 'create_guess'
184+
185+
hdxm_set = param.ClassSelector(HDXMeasurementSet)
186+
187+
select_state = param.String(doc='If set, only use this state for creating initial guesses')
188+
189+
rates_df = param.ClassSelector(pd.DataFrame)
190+
191+
out = param.ClassSelector((pd.Series, pd.DataFrame))
192+
193+
def execute(self, *args, **kwargs):
194+
if self.select_state:
195+
hdxm = self.hdxm_set.get(self.select_state)
196+
if self.rates_df.columns.nlevels == 2:
197+
rates_series = self.rates_df[(self.select_state, 'rate')]
198+
else:
199+
rates_series = self.rates_df['rate']
200+
201+
guess = hdxm.guess_deltaG(rates_series)
202+
203+
else:
204+
rates = self.rates_df.xs('rate', level=-1, axis=1)
205+
guess = self.hdxm_set.guess_deltaG(rates)
206+
207+
self.out = guess
208+
209+
210+
class FitGlobalBatch(Task):
    """Run a global batch Gibbs fit over the measurement set."""

    _type = 'fit_global_batch'

    hdxm_set = param.ClassSelector(HDXMeasurementSet)

    initial_guess = param.ClassSelector(
        (pd.Series, pd.DataFrame), doc='Initial guesses for fits')

    out = param.ClassSelector(TorchFitResult)

    def execute(self, *args, **kwargs):
        # kwargs are forwarded to the fit; positional args are accepted for
        # interface uniformity but unused (matches the original behavior)
        self.out = fit_gibbs_global_batch(self.hdxm_set, self.initial_guess, **kwargs)
224+
225+
226+
class SaveFitResult(Task):
    """Write a fit result to disk under ``cwd / output_dir``."""

    _type = 'save_fit_result'

    # fit result to persist
    fit_result = param.ClassSelector(TorchFitResult)

    # output directory name, joined onto the task's `cwd`
    output_dir = param.String()

    def execute(self, *args, **kwargs):
        # delegates to pyhdx.fileIO.save_fitresult
        save_fitresult(self.cwd / self.output_dir, self.fit_result)
235+
236+
237+
class JobParser(object):
    """Parse and execute a yaml job specification as a sequence of Task steps.

    Parameters
    ----------
    job_spec : dict
        Job specification; its ``'steps'`` entry is an ordered list of task
        specs, each with a ``'task'`` key naming a ``Task`` subclass by
        its ``_type``.
    cwd : pathlib.Path, optional
        Working directory passed to each task; defaults to the current
        working directory.
    """

    # NOTE: the original declared `cwd = param.ClassSelector(...)` here, but
    # param descriptors are inert on a plain `object` subclass and the
    # attribute is always set in __init__; the dead declaration was removed.

    def __init__(self, job_spec, cwd=None):
        self.job_spec = job_spec
        self.cwd = cwd or Path().cwd()

        # Executed tasks by name, available for '$(task.attr)' references
        self.tasks = {}
        self.task_classes = {cls._type: cls for cls in gen_subclasses(Task) if getattr(cls, "_type", None)}

    def resolve_var(self, var_string):
        """Resolve a dotted ``'task_name.attr[.attr...]'`` reference against a
        previously executed task and return the referenced value."""
        task_name, *attrs = var_string.split('.')

        return reduce(getattr, attrs, self.tasks[task_name])

    def execute(self):
        """Run all steps of the job spec in order.

        String parameter values of the form ``$(task.attr)`` are substituted
        with attributes of previously executed tasks.
        """
        # Keys that configure the step itself rather than the task object
        skip = {'args', 'kwargs', 'task'}

        for task_spec in self.job_spec['steps']:
            task_klass = self.task_classes[task_spec['task']]

            resolved_params = {}
            for par_name in task_spec.keys() - skip:
                value = task_spec[par_name]
                if isinstance(value, str):
                    # '$(other_task.attr)' pulls a value from a prior task
                    m = re.findall(r'\$\((.*?)\)', value)
                    if m:
                        value = self.resolve_var(m[0])
                resolved_params[par_name] = value
            task = task_klass(cwd=self.cwd, **resolved_params)
            task.execute(*task_spec.get('args', []), **task_spec.get('kwargs', {}))

            self.tasks[task.name] = task
271+
272+
114273
def yaml_to_hdxmset(yaml_dict, data_dir=None, **kwargs):
    """Read files according to the `yaml_dict` spec from `data_dir` into an
    ``HDXMeasurementSet``.

    Deprecated: use ``StateParser`` instead.
    """

    # Pass the DeprecationWarning category, consistent with yaml_to_hdxm
    warnings.warn("yaml_to_hdxmset is deprecated, use 'StateParser'", DeprecationWarning)
    hdxm_list = []
    for k, v in yaml_dict.items():
        # one HDXMeasurement per top-level state entry, named after its key
        hdxm = yaml_to_hdxm(v, data_dir=data_dir, name=k)
        hdxm_list.append(hdxm)

    return HDXMeasurementSet(hdxm_list)
123283

284+
# todo configurable
def _positive_exposure(df):
    """Keep only rows measured at a non-zero exposure time."""
    return df.query('exposure > 0')


# Default data filters applied to peptide tables before measurement creation
default_filters = [_positive_exposure]
288+
124289

125290
def yaml_to_hdxm(yaml_dict, data_dir=None, data_filters=None, **kwargs):
126291
# todo perhaps classmethod on HDXMeasurement object?
@@ -142,7 +307,7 @@ def yaml_to_hdxm(yaml_dict, data_dir=None, data_filters=None, **kwargs):
142307
Output data object as specified by `yaml_dict`.
143308
"""
144309

145-
warnings.warn('This method is deprecated in favor of YamlParser', DeprecationWarning)
310+
warnings.warn('This method is deprecated in favor of StateParser', DeprecationWarning)
146311

147312
if data_dir is not None:
148313
input_files = [Path(data_dir) / fname for fname in yaml_dict["filenames"]]
@@ -270,3 +435,5 @@ def load_from_yaml_v040b2(yaml_dict, data_dir=None, **kwargs): # pragma: no cov
270435
)
271436

272437
return hdxm
438+
439+

pyhdx/fileIO.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -371,7 +371,7 @@ def save_fitresult(output_dir, fit_result, log_lines=None):
371371
dataframe_to_file(output_dir / "losses.csv", fit_result.losses)
372372
dataframe_to_file(output_dir / "losses.txt", fit_result.losses, fmt="pprint")
373373

374-
if isinstance(fit_result.hdxm_set, pyhdx.HDXMeasurement):
374+
if isinstance(fit_result.hdxm_set, pyhdx.HDXMeasurement): # check, but this should always be hdxm_set
375375
fit_result.hdxm_set.to_file(output_dir / "HDXMeasurement.csv")
376376
if isinstance(fit_result.hdxm_set, pyhdx.HDXMeasurementSet):
377377
fit_result.hdxm_set.to_file(output_dir / "HDXMeasurements.csv")

pyhdx/fitting.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -929,6 +929,7 @@ class GenericFitResult:
929929

930930
@dataclass
931931
class RatesFitResult:
932+
"""Accumulates multiple Generic/KineticsFit Results"""
932933
results: list
933934

934935
@property

pyhdx/models.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1219,6 +1219,12 @@ def __iter__(self):
12191219
def __getitem__(self, item):
12201220
return self.hdxm_list.__getitem__(item)
12211221

1222+
def get(self, name):
    """Return the HDXMeasurement whose name equals `name`.

    Raises ValueError when no measurement with that name exists.
    """
    position = self.names.index(name)
    return self[position]
1227+
12221228
@property
12231229
def Ns(self):
12241230
return len(self.hdxm_list)

pyhdx/web/controllers.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from proplot import to_hex
2121
from skimage.filters import threshold_multiotsu
2222

23-
from pyhdx.batch_processing import YamlParser
23+
from pyhdx.batch_processing import StateParser
2424
from pyhdx.config import cfg
2525
from pyhdx.fileIO import read_dynamx, csv_to_dataframe, dataframe_to_stringio
2626
from pyhdx.fitting import (
@@ -499,7 +499,7 @@ def _add_dataset_batch(self):
499499
ios = {name: StringIO(byte_content.decode("UTF-8")) for name, byte_content in zip(self.widgets['input_files'].filename, self.input_files)}
500500
filters = [lambda df: df.query('exposure > 0')]
501501

502-
parser = YamlParser(yaml_dict, data_src=ios, data_filters=filters)
502+
parser = StateParser(yaml_dict, data_src=ios, data_filters=filters)
503503

504504
for state in yaml_dict.keys():
505505
hdxm = parser.load_hdxm(state, name=state)

setup.cfg

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ install_requires =
3434
dask
3535
distributed
3636
packaging
37+
param
38+
3739
python_requires =
3840
>=3.8
3941

@@ -46,7 +48,6 @@ console_scripts =
4648
web =
4749
panel>=0.12.6
4850
bokeh
49-
param
5051
holoviews
5152
colorcet >= 3.0.0
5253
hvplot

templates/02_guesses_from_yaml.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""Load HDX-MS data from yaml spec and perform initial guess of exchange rates"""
2-
from pyhdx.batch_processing import YamlParser
2+
from pyhdx.batch_processing import StateParser
33
from pathlib import Path
44
from pyhdx.fitting import fit_rates_weighted_average
55
import yaml
@@ -16,7 +16,7 @@
1616
# Requires local_cluster.py to be running (or other Dask client on default address in config)
1717
client = default_client()
1818

19-
parser = YamlParser(data_dict, data_src=data_dir)
19+
parser = StateParser(data_dict, data_src=data_dir)
2020
for name in data_dict.keys():
2121
print(name)
2222
dic = data_dict[name]

templates/06_fitting_with_logs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
"""Perform fitting with a range of regularizers"""
2-
from pyhdx.batch_processing import yaml_to_hdxmset, YamlParser
2+
from pyhdx.batch_processing import yaml_to_hdxmset, StateParser
33
from pathlib import Path
44
from pyhdx.fitting import fit_gibbs_global_batch
55
import yaml
@@ -20,7 +20,7 @@
2020
output_dir = current_dir / 'fit'
2121
output_dir.mkdir(exist_ok=True)
2222

23-
parser = YamlParser(data_dict, data_src=input_dir)
23+
parser = StateParser(data_dict, data_src=input_dir)
2424
hdx_set = parser.load_hdxmset()
2525

2626
rates_list = [csv_to_protein(current_dir / 'guesses' / f'{name}_rates_guess.csv')['rate'] for name in data_dict.keys()]

templates/12_jobfiles.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
"""
2+
Execute a PyHDX data processing pipeline according to a yaml jobfile specification
3+
4+
5+
"""
6+
7+
8+
from pathlib import Path
9+
from pyhdx.batch_processing import JobParser
10+
import yaml
11+
12+
#%%
13+
# Pycharm scientific mode
14+
if '__file__' not in locals():
15+
__file__ = Path().cwd() / 'templates' / 'script.py'
16+
17+
current_dir = Path(__file__).parent
18+
output_dir = current_dir / 'output'
19+
output_dir.mkdir(exist_ok=True)
20+
test_data_dir = current_dir.parent / 'tests' / 'test_data'
21+
input_dir = test_data_dir / 'input'
22+
23+
#%%
24+
25+
job_spec = yaml.safe_load((input_dir / 'jobfile.yaml').read_text())
26+
job_parser = JobParser(job_spec, cwd=input_dir)
27+
job_parser.execute()

tests/test_batchprocessing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from pyhdx.batch_processing import yaml_to_hdxm, yaml_to_hdxmset, YamlParser
1+
from pyhdx.batch_processing import yaml_to_hdxm, yaml_to_hdxmset, StateParser
22
from pyhdx.models import HDXMeasurement, HDXMeasurementSet
33
import numpy as np
44
from pathlib import Path
@@ -27,7 +27,7 @@ def test_load_from_yaml(self):
2727
assert isinstance(hdxm_set, HDXMeasurementSet)
2828
assert hdxm_set.names == list(data_dict.keys())
2929

30-
parser = YamlParser(data_dict, data_src=input_dir)
30+
parser = StateParser(data_dict, data_src=input_dir)
3131

3232
hdxm = parser.load_hdxm('SecB_tetramer')
3333
assert isinstance(hdxm, HDXMeasurement)

0 commit comments

Comments
 (0)