|
17 | 17 | from tqdm import tqdm |
18 | 18 |
|
19 | 19 | from pyhdx.config import cfg |
20 | | -from pyhdx.fileIO import load_fitresult |
21 | 20 | from pyhdx.support import ( |
22 | 21 | apply_cmap, |
23 | 22 | autowrap, |
@@ -207,7 +206,7 @@ def residue_time_scatter_figure( |
207 | 206 | nrows = figure_kwargs.pop("nrows", int(np.ceil(n_subplots / ncols))) |
208 | 207 | figure_width = figure_kwargs.pop("width", cfg.plotting.page_width) / 25.4 |
209 | 208 | refaspect = figure_kwargs.pop("refaspect", cfg.plotting.residue_scatter_aspect) |
210 | | - cbar_width = figure_kwargs.pop("cbar_width", cfg.plotting.cbar_width) / 25.4 |
| 209 | + # cbar_width = figure_kwargs.pop("cbar_width", cfg.plotting.cbar_width) / 25.4 |
211 | 210 |
|
212 | 211 | cmap = uplt.Colormap(cmap) # todo allow None as cmap |
213 | 212 | norm = norm or uplt.Norm("linear", vmin=0, vmax=1) |
@@ -1654,204 +1653,3 @@ def plot_all(self, **kwargs): |
1654 | 1653 | pbar.set_description(plot_type) |
1655 | 1654 | fig_kwargs = kwargs.get(plot_type, {}) |
1656 | 1655 | self.save_figure(plot_type, **fig_kwargs) |
1657 | | - |
1658 | | - |
1659 | | -def plot_fitresults( |
1660 | | - fitresult_path, |
1661 | | - reference=None, |
1662 | | - plots="all", |
1663 | | - renew=False, |
1664 | | - cmap_and_norm=None, |
1665 | | - output_path=None, |
1666 | | - output_type=".png", |
1667 | | - **save_kwargs, |
1668 | | -): |
1669 | | - """ |
1670 | | -
|
1671 | | - Parameters |
1672 | | - ---------- |
1673 | | - fitresult_path |
1674 | | - plots |
1675 | | - renew |
1676 | | - cmap_and_norm: :obj:`dict`, optional |
1677 | | - Dictionary with cmap and norms to use. If `None`, reverts to defaults. |
1678 | | - Dict format: {'dG': (cmap, norm), 'ddG': (cmap, norm)} |
1679 | | -
|
1680 | | - output_type: list or str |
1681 | | -
|
1682 | | - Returns |
1683 | | - ------- |
1684 | | -
|
1685 | | - """ |
1686 | | - |
1687 | | - raise DeprecationWarning("This function is deprecated, use FitResultPlot.plot_all instead") |
1688 | | - # batch results only |
1689 | | - history_path = fitresult_path / "model_history.csv" |
1690 | | - output_path = output_path or fitresult_path |
1691 | | - output_type = list([output_type]) if isinstance(output_type, str) else output_type |
1692 | | - fitresult = load_fitresult(fitresult_path) |
1693 | | - |
1694 | | - protein_states = fitresult.output.df.columns.get_level_values(0).unique() |
1695 | | - |
1696 | | - if isinstance(reference, int): |
1697 | | - reference_state = protein_states[reference] |
1698 | | - elif reference in protein_states: |
1699 | | - reference_state = reference |
1700 | | - elif reference is None: |
1701 | | - reference_state = None |
1702 | | - else: |
1703 | | - raise ValueError(f"Invalid value {reference!r} for 'reference'") |
1704 | | - |
1705 | | - # todo needs tidying up |
1706 | | - cmap_and_norm = cmap_and_norm or {} |
1707 | | - dG_cmap, dG_norm = cmap_and_norm.get("dG", (None, None)) |
1708 | | - dG_cmap_default, dG_norm_default = default_cmap_norm("dG") |
1709 | | - ddG_cmap, ddG_norm = cmap_and_norm.get("ddG", (None, None)) |
1710 | | - ddG_cmap_default, ddG_norm_default = default_cmap_norm("ddG") |
1711 | | - dG_cmap = ddG_cmap or dG_cmap_default |
1712 | | - dG_norm = dG_norm or dG_norm_default |
1713 | | - ddG_cmap = ddG_cmap or ddG_cmap_default |
1714 | | - ddG_norm = ddG_norm or ddG_norm_default |
1715 | | - |
1716 | | - # check_exists = lambda x: False if renew else x.exists() |
1717 | | - # todo add logic for checking renew or not |
1718 | | - |
1719 | | - if plots == "all": |
1720 | | - plots = [ |
1721 | | - "loss", |
1722 | | - "rfu_coverage", |
1723 | | - "rfu_scatter", |
1724 | | - "dG_scatter", |
1725 | | - "ddG_scatter", |
1726 | | - "linear_bars", |
1727 | | - "rainbowclouds", |
1728 | | - "peptide_mse", |
1729 | | - ] |
1730 | | - |
1731 | | - # def check_update(pth, fname, extensions, renew): |
1732 | | - # # Returns True if the target graph should be renewed or not |
1733 | | - # if renew: |
1734 | | - # return True |
1735 | | - # else: |
1736 | | - # pths = [pth / (fname + ext) for ext in extensions] |
1737 | | - # return any([not pth.exists() for pth in pths]) |
1738 | | - |
1739 | | - # plots = [p for p in plots if check_update(output_path, p, output_type, renew)] |
1740 | | - |
1741 | | - if "loss" in plots: |
1742 | | - loss_df = fitresult.losses |
1743 | | - loss_df.plot() |
1744 | | - |
1745 | | - mse_loss = loss_df["mse_loss"] |
1746 | | - reg_loss = loss_df.iloc[:, 1:].sum(axis=1) |
1747 | | - reg_percentage = 100 * reg_loss / (mse_loss + reg_loss) |
1748 | | - fig = plt.gcf() |
1749 | | - ax = plt.gca() |
1750 | | - ax1 = ax.twinx() |
1751 | | - reg_percentage.plot(ax=ax1, color="k") |
1752 | | - ax1.set_xlim(0, None) |
1753 | | - for ext in output_type: |
1754 | | - f_out = output_path / ("loss" + ext) |
1755 | | - plt.savefig(f_out) |
1756 | | - plt.close(fig) |
1757 | | - |
1758 | | - if "rfu_coverage" in plots: |
1759 | | - for hdxm in fitresult.hdxm_set: |
1760 | | - fig, axes, cbar_ax = peptide_coverage_figure(hdxm.data) |
1761 | | - for ext in output_type: |
1762 | | - f_out = output_path / (f"rfu_coverage_{hdxm.name}" + ext) |
1763 | | - plt.savefig(f_out) |
1764 | | - plt.close(fig) |
1765 | | - |
1766 | | - # todo rfu_scatter_timepoint |
1767 | | - |
1768 | | - if "rfu_scatter" in plots: |
1769 | | - fig, axes, cbar = residue_scatter_figure(fitresult.hdxm_set) |
1770 | | - for ext in output_type: |
1771 | | - f_out = output_path / ("rfu_scatter" + ext) |
1772 | | - plt.savefig(f_out) |
1773 | | - plt.close(fig) |
1774 | | - |
1775 | | - if "dG_scatter" in plots: |
1776 | | - fig, axes, cbars = dG_scatter_figure(fitresult.output.df, cmap=dG_cmap, norm=dG_norm) |
1777 | | - for ext in output_type: |
1778 | | - f_out = output_path / ("dG_scatter" + ext) |
1779 | | - plt.savefig(f_out) |
1780 | | - plt.close(fig) |
1781 | | - |
1782 | | - if "ddG_scatter" in plots: |
1783 | | - fig, axes, cbars = ddG_scatter_figure( |
1784 | | - fitresult.output.df, reference=reference, cmap=ddG_cmap, norm=ddG_norm |
1785 | | - ) |
1786 | | - for ext in output_type: |
1787 | | - f_out = output_path / ("ddG_scatter" + ext) |
1788 | | - plt.savefig(f_out) |
1789 | | - plt.close(fig) |
1790 | | - |
1791 | | - if "linear_bars" in plots: |
1792 | | - fig, axes = linear_bars_figure(fitresult.output.df) |
1793 | | - for ext in output_type: |
1794 | | - f_out = output_path / ("dG_linear_bars" + ext) |
1795 | | - plt.savefig(f_out) |
1796 | | - plt.close(fig) |
1797 | | - |
1798 | | - if reference_state: |
1799 | | - fig, axes = linear_bars_figure(fitresult.output.df, reference=reference) |
1800 | | - for ext in output_type: |
1801 | | - f_out = output_path / ("ddG_linear_bars" + ext) |
1802 | | - plt.savefig(f_out) |
1803 | | - plt.close(fig) |
1804 | | - |
1805 | | - if "rainbowclouds" in plots: |
1806 | | - fig, ax = rainbowclouds_figure(fitresult.output.df) |
1807 | | - for ext in output_type: |
1808 | | - f_out = output_path / ("dG_rainbowclouds" + ext) |
1809 | | - plt.savefig(f_out) |
1810 | | - plt.close(fig) |
1811 | | - |
1812 | | - if reference_state: |
1813 | | - fig, axes = rainbowclouds_figure(fitresult.output.df, reference=reference) |
1814 | | - for ext in output_type: |
1815 | | - f_out = output_path / ("ddG_rainbowclouds" + ext) |
1816 | | - plt.savefig(f_out) |
1817 | | - plt.close(fig) |
1818 | | - |
1819 | | - if "peptide_mse" in plots: |
1820 | | - fig, axes, cbars = peptide_mse_figure(fitresult.get_peptide_mse()) |
1821 | | - for ext in output_type: |
1822 | | - f_out = output_path / ("peptide_mse" + ext) |
1823 | | - plt.savefig(f_out) |
1824 | | - plt.close(fig) |
1825 | | - |
1826 | | - # |
1827 | | - # if 'history' in plots: |
1828 | | - # for h_df, name in zip(history_list, names): |
1829 | | - # output_path = fitresult_path / f'{name}history.png' |
1830 | | - # if check_exists(output_path): |
1831 | | - # break |
1832 | | - # |
1833 | | - # num = len(h_df.columns) |
1834 | | - # max_epochs = max([int(c) for c in h_df.columns]) |
1835 | | - # |
1836 | | - # cmap = mpl.cm.get_cmap('winter') |
1837 | | - # norm = mpl.colors.Normalize(vmin=1, vmax=max_epochs) |
1838 | | - # colors = iter(cmap(np.linspace(0, 1, num=num))) |
1839 | | - # |
1840 | | - # fig, axes = uplt.subplots(nrows=1, width=width, aspect=aspect) |
1841 | | - # ax = axes[0] |
1842 | | - # for key in h_df: |
1843 | | - # c = next(colors) |
1844 | | - # to_hex(c) |
1845 | | - # |
1846 | | - # ax.scatter(h_df.index, h_df[key] * 1e-3, color=to_hex(c), **scatter_kwargs) |
1847 | | - # ax.format(xlabel=r_xlabel, ylabel=dG_ylabel) |
1848 | | - # |
1849 | | - # values = np.linspace(0, max_epochs, endpoint=True, num=num) |
1850 | | - # colors = cmap(norm(values)) |
1851 | | - # tick_labels = np.linspace(0, max_epochs, num=5) |
1852 | | - # |
1853 | | - # cbar = fig.colorbar(colors, values=values, ticks=tick_labels, space=0, width=cbar_width, label='Epochs') |
1854 | | - # ax.format(yticklabelloc='None', ytickloc='None') |
1855 | | - # |
1856 | | - # plt.savefig(output_path) |
1857 | | - # plt.close(fig) |
0 commit comments