Skip to content

Commit 47b36f7

Browse files
committed
linting
1 parent fd2bf92 commit 47b36f7

File tree

16 files changed

+88
-313
lines changed

16 files changed

+88
-313
lines changed

pyhdx/datasets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from hdxms_datasets import *
1+
from hdxms_datasets import * # noqa F403

pyhdx/fileIO.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,17 @@
55
import os
66
import re
77
import shutil
8+
import warnings
89
from datetime import datetime
9-
from io import StringIO, BytesIO
10-
from pathlib import Path
11-
from typing import Union, Literal, Tuple, List, TextIO, Optional, TYPE_CHECKING, Any, BinaryIO
1210
from importlib import import_module
13-
import torch.nn as nn
14-
import torch as t
11+
from io import BytesIO, StringIO
12+
from pathlib import Path
13+
from typing import TYPE_CHECKING, Any, BinaryIO, List, Literal, Optional, TextIO, Tuple, Union
14+
1515
import pandas as pd
16+
import torch as t
17+
import torch.nn as nn
1618
import yaml
17-
import warnings
1819

1920
import pyhdx
2021

@@ -217,7 +218,7 @@ def dataframe_to_stringio(
217218
sio.write(prefix + f'{now.strftime("%Y/%m/%d %H:%M:%S")} ({int(now.timestamp())}) \n')
218219

219220
json_header = {}
220-
if include_metadata == True and "metadata" in df.attrs:
221+
if include_metadata is True and "metadata" in df.attrs:
221222
json_header["metadata"] = df.attrs["metadata"]
222223
elif include_metadata and isinstance(include_metadata, dict):
223224
json_header["metadata"] = include_metadata
@@ -366,10 +367,6 @@ def load_fitresult(fit_dir: os.PathLike) -> Union[TorchFitResult, TorchFitResult
366367
result_klass = pyhdx.fitting_torch.TorchFitResult
367368
elif pth.is_file():
368369
raise DeprecationWarning("`load_fitresult` only loads from fit result directories")
369-
fit_result = csv_to_dataframe(fit_dir)
370-
assert isinstance(
371-
hdxm, pyhdx.HDXMeasurement
372-
), "No valid HDXMeasurement data object supplied"
373370
else:
374371
raise ValueError("Specified fit result path is not a directory")
375372

pyhdx/fitting.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,27 +3,27 @@
33
from collections import namedtuple
44
from dataclasses import dataclass
55
from functools import partial
6-
from typing import Union, Optional, Any, Literal
6+
from typing import Any, Literal, Optional, Union
77

88
import numpy as np
99
import pandas as pd
1010
import torch
1111
from dask.distributed import Client, worker_client
12-
from scipy.optimize import Bounds, minimize, OptimizeResult
12+
from scipy.optimize import Bounds, OptimizeResult, minimize
1313
from symfit import Fit
1414
from symfit.core.minimizers import DifferentialEvolution, Powell
1515
from tqdm.auto import tqdm, trange
1616

17+
from pyhdx.config import cfg
1718
from pyhdx.fit_models import (
1819
SingleKineticModel,
1920
TwoComponentAssociationModel,
2021
TwoComponentDissociationModel,
2122
)
2223
from pyhdx.fitting_torch import DeltaGFit, TorchFitResult
2324
from pyhdx.local_cluster import DummyClient
24-
from pyhdx.support import temporary_seed, pbar_decorator, multiindex_astype
25-
from pyhdx.models import HDXMeasurementSet, HDXTimepoint, HDXMeasurement
26-
from pyhdx.config import cfg
25+
from pyhdx.models import HDXMeasurement, HDXMeasurementSet, HDXTimepoint
26+
from pyhdx.support import multiindex_astype, pbar_decorator, temporary_seed
2727

2828
EmptyResult = namedtuple("EmptyResult", ["chi_squared", "params"])
2929
er = EmptyResult(np.nan, {k: np.nan for k in ["tau1", "tau2", "r"]})
@@ -187,7 +187,7 @@ def fit_rates_weighted_average(
187187
if pbar:
188188
raise NotImplementedError()
189189
else:
190-
inc = lambda: None
190+
pass
191191

192192
results = []
193193

@@ -385,9 +385,9 @@ def _fit_single_d_update(
385385
# d_uptake = hdx_t.data["uptake_corrected"].values
386386
Nr = X.shape[1] # number of residues / parameters
387387

388-
if bounds == True:
388+
if bounds is True:
389389
bounds = Bounds(lb=np.zeros(Nr), ub=np.ones(Nr))
390-
elif bounds == False:
390+
elif bounds is False:
391391
bounds = None
392392

393393
args = (X, d_uptake, r1)
@@ -561,7 +561,7 @@ def closure():
561561
iter = trange(epochs) if verbose else range(epochs)
562562
for epoch in iter:
563563
optimizer_obj.zero_grad()
564-
loss = optimizer_obj.step(closure)
564+
loss = optimizer_obj.step(closure) # noqa
565565

566566
for cb in callbacks:
567567
cb(epoch, model, optimizer_obj)

pyhdx/plot.py

Lines changed: 1 addition & 203 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from tqdm import tqdm
1818

1919
from pyhdx.config import cfg
20-
from pyhdx.fileIO import load_fitresult
2120
from pyhdx.support import (
2221
apply_cmap,
2322
autowrap,
@@ -207,7 +206,7 @@ def residue_time_scatter_figure(
207206
nrows = figure_kwargs.pop("nrows", int(np.ceil(n_subplots / ncols)))
208207
figure_width = figure_kwargs.pop("width", cfg.plotting.page_width) / 25.4
209208
refaspect = figure_kwargs.pop("refaspect", cfg.plotting.residue_scatter_aspect)
210-
cbar_width = figure_kwargs.pop("cbar_width", cfg.plotting.cbar_width) / 25.4
209+
# cbar_width = figure_kwargs.pop("cbar_width", cfg.plotting.cbar_width) / 25.4
211210

212211
cmap = uplt.Colormap(cmap) # todo allow None as cmap
213212
norm = norm or uplt.Norm("linear", vmin=0, vmax=1)
@@ -1654,204 +1653,3 @@ def plot_all(self, **kwargs):
16541653
pbar.set_description(plot_type)
16551654
fig_kwargs = kwargs.get(plot_type, {})
16561655
self.save_figure(plot_type, **fig_kwargs)
1657-
1658-
1659-
def plot_fitresults(
1660-
fitresult_path,
1661-
reference=None,
1662-
plots="all",
1663-
renew=False,
1664-
cmap_and_norm=None,
1665-
output_path=None,
1666-
output_type=".png",
1667-
**save_kwargs,
1668-
):
1669-
"""
1670-
1671-
Parameters
1672-
----------
1673-
fitresult_path
1674-
plots
1675-
renew
1676-
cmap_and_norm: :obj:`dict`, optional
1677-
Dictionary with cmap and norms to use. If `None`, reverts to defaults.
1678-
Dict format: {'dG': (cmap, norm), 'ddG': (cmap, norm)}
1679-
1680-
output_type: list or str
1681-
1682-
Returns
1683-
-------
1684-
1685-
"""
1686-
1687-
raise DeprecationWarning("This function is deprecated, use FitResultPlot.plot_all instead")
1688-
# batch results only
1689-
history_path = fitresult_path / "model_history.csv"
1690-
output_path = output_path or fitresult_path
1691-
output_type = list([output_type]) if isinstance(output_type, str) else output_type
1692-
fitresult = load_fitresult(fitresult_path)
1693-
1694-
protein_states = fitresult.output.df.columns.get_level_values(0).unique()
1695-
1696-
if isinstance(reference, int):
1697-
reference_state = protein_states[reference]
1698-
elif reference in protein_states:
1699-
reference_state = reference
1700-
elif reference is None:
1701-
reference_state = None
1702-
else:
1703-
raise ValueError(f"Invalid value {reference!r} for 'reference'")
1704-
1705-
# todo needs tidying up
1706-
cmap_and_norm = cmap_and_norm or {}
1707-
dG_cmap, dG_norm = cmap_and_norm.get("dG", (None, None))
1708-
dG_cmap_default, dG_norm_default = default_cmap_norm("dG")
1709-
ddG_cmap, ddG_norm = cmap_and_norm.get("ddG", (None, None))
1710-
ddG_cmap_default, ddG_norm_default = default_cmap_norm("ddG")
1711-
dG_cmap = ddG_cmap or dG_cmap_default
1712-
dG_norm = dG_norm or dG_norm_default
1713-
ddG_cmap = ddG_cmap or ddG_cmap_default
1714-
ddG_norm = ddG_norm or ddG_norm_default
1715-
1716-
# check_exists = lambda x: False if renew else x.exists()
1717-
# todo add logic for checking renew or not
1718-
1719-
if plots == "all":
1720-
plots = [
1721-
"loss",
1722-
"rfu_coverage",
1723-
"rfu_scatter",
1724-
"dG_scatter",
1725-
"ddG_scatter",
1726-
"linear_bars",
1727-
"rainbowclouds",
1728-
"peptide_mse",
1729-
]
1730-
1731-
# def check_update(pth, fname, extensions, renew):
1732-
# # Returns True if the target graph should be renewed or not
1733-
# if renew:
1734-
# return True
1735-
# else:
1736-
# pths = [pth / (fname + ext) for ext in extensions]
1737-
# return any([not pth.exists() for pth in pths])
1738-
1739-
# plots = [p for p in plots if check_update(output_path, p, output_type, renew)]
1740-
1741-
if "loss" in plots:
1742-
loss_df = fitresult.losses
1743-
loss_df.plot()
1744-
1745-
mse_loss = loss_df["mse_loss"]
1746-
reg_loss = loss_df.iloc[:, 1:].sum(axis=1)
1747-
reg_percentage = 100 * reg_loss / (mse_loss + reg_loss)
1748-
fig = plt.gcf()
1749-
ax = plt.gca()
1750-
ax1 = ax.twinx()
1751-
reg_percentage.plot(ax=ax1, color="k")
1752-
ax1.set_xlim(0, None)
1753-
for ext in output_type:
1754-
f_out = output_path / ("loss" + ext)
1755-
plt.savefig(f_out)
1756-
plt.close(fig)
1757-
1758-
if "rfu_coverage" in plots:
1759-
for hdxm in fitresult.hdxm_set:
1760-
fig, axes, cbar_ax = peptide_coverage_figure(hdxm.data)
1761-
for ext in output_type:
1762-
f_out = output_path / (f"rfu_coverage_{hdxm.name}" + ext)
1763-
plt.savefig(f_out)
1764-
plt.close(fig)
1765-
1766-
# todo rfu_scatter_timepoint
1767-
1768-
if "rfu_scatter" in plots:
1769-
fig, axes, cbar = residue_scatter_figure(fitresult.hdxm_set)
1770-
for ext in output_type:
1771-
f_out = output_path / ("rfu_scatter" + ext)
1772-
plt.savefig(f_out)
1773-
plt.close(fig)
1774-
1775-
if "dG_scatter" in plots:
1776-
fig, axes, cbars = dG_scatter_figure(fitresult.output.df, cmap=dG_cmap, norm=dG_norm)
1777-
for ext in output_type:
1778-
f_out = output_path / ("dG_scatter" + ext)
1779-
plt.savefig(f_out)
1780-
plt.close(fig)
1781-
1782-
if "ddG_scatter" in plots:
1783-
fig, axes, cbars = ddG_scatter_figure(
1784-
fitresult.output.df, reference=reference, cmap=ddG_cmap, norm=ddG_norm
1785-
)
1786-
for ext in output_type:
1787-
f_out = output_path / ("ddG_scatter" + ext)
1788-
plt.savefig(f_out)
1789-
plt.close(fig)
1790-
1791-
if "linear_bars" in plots:
1792-
fig, axes = linear_bars_figure(fitresult.output.df)
1793-
for ext in output_type:
1794-
f_out = output_path / ("dG_linear_bars" + ext)
1795-
plt.savefig(f_out)
1796-
plt.close(fig)
1797-
1798-
if reference_state:
1799-
fig, axes = linear_bars_figure(fitresult.output.df, reference=reference)
1800-
for ext in output_type:
1801-
f_out = output_path / ("ddG_linear_bars" + ext)
1802-
plt.savefig(f_out)
1803-
plt.close(fig)
1804-
1805-
if "rainbowclouds" in plots:
1806-
fig, ax = rainbowclouds_figure(fitresult.output.df)
1807-
for ext in output_type:
1808-
f_out = output_path / ("dG_rainbowclouds" + ext)
1809-
plt.savefig(f_out)
1810-
plt.close(fig)
1811-
1812-
if reference_state:
1813-
fig, axes = rainbowclouds_figure(fitresult.output.df, reference=reference)
1814-
for ext in output_type:
1815-
f_out = output_path / ("ddG_rainbowclouds" + ext)
1816-
plt.savefig(f_out)
1817-
plt.close(fig)
1818-
1819-
if "peptide_mse" in plots:
1820-
fig, axes, cbars = peptide_mse_figure(fitresult.get_peptide_mse())
1821-
for ext in output_type:
1822-
f_out = output_path / ("peptide_mse" + ext)
1823-
plt.savefig(f_out)
1824-
plt.close(fig)
1825-
1826-
#
1827-
# if 'history' in plots:
1828-
# for h_df, name in zip(history_list, names):
1829-
# output_path = fitresult_path / f'{name}history.png'
1830-
# if check_exists(output_path):
1831-
# break
1832-
#
1833-
# num = len(h_df.columns)
1834-
# max_epochs = max([int(c) for c in h_df.columns])
1835-
#
1836-
# cmap = mpl.cm.get_cmap('winter')
1837-
# norm = mpl.colors.Normalize(vmin=1, vmax=max_epochs)
1838-
# colors = iter(cmap(np.linspace(0, 1, num=num)))
1839-
#
1840-
# fig, axes = uplt.subplots(nrows=1, width=width, aspect=aspect)
1841-
# ax = axes[0]
1842-
# for key in h_df:
1843-
# c = next(colors)
1844-
# to_hex(c)
1845-
#
1846-
# ax.scatter(h_df.index, h_df[key] * 1e-3, color=to_hex(c), **scatter_kwargs)
1847-
# ax.format(xlabel=r_xlabel, ylabel=dG_ylabel)
1848-
#
1849-
# values = np.linspace(0, max_epochs, endpoint=True, num=num)
1850-
# colors = cmap(norm(values))
1851-
# tick_labels = np.linspace(0, max_epochs, num=5)
1852-
#
1853-
# cbar = fig.colorbar(colors, values=values, ticks=tick_labels, space=0, width=cbar_width, label='Epochs')
1854-
# ax.format(yticklabelloc='None', ytickloc='None')
1855-
#
1856-
# plt.savefig(output_path)
1857-
# plt.close(fig)

pyhdx/tol_colors.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -551,7 +551,7 @@ def __rainbow_discrete(self, lut=None):
551551
28,
552552
],
553553
]
554-
if lut == None or lut < 1 or lut > 23:
554+
if lut == None or lut < 1 or lut > 23: # noqa: E711
555555
lut = 22
556556
self.cmap = discretemap(self.cname, [clrs[i] for i in indexes[lut - 1]])
557557
if lut == 23:
@@ -585,7 +585,7 @@ def tol_cmap(colormap=None, lut=None):
585585
Parameter lut is ignored for all colormaps except 'rainbow_discrete'.
586586
"""
587587
obj = TOLcmaps()
588-
if colormap == None:
588+
if colormap == None: # noqa: E711
589589
return obj.namelist
590590
if colormap not in obj.namelist:
591591
colormap = "rainbow_PuRd"
@@ -617,7 +617,7 @@ def tol_cset(colorset=None):
617617
"medium-contrast",
618618
"light",
619619
)
620-
if colorset == None:
620+
if colorset == None: # noqa: E711
621621
return namelist
622622
if colorset not in namelist:
623623
colorset = "bright"

0 commit comments

Comments
 (0)