diff --git a/nbs_tests/model/model_interface.ipynb b/nbs_tests/model/model_interface.ipynb index bed41142..53d58c52 100644 --- a/nbs_tests/model/model_interface.ipynb +++ b/nbs_tests/model/model_interface.ipynb @@ -485,7 +485,7 @@ "outputs": [], "source": [ "from peptdeep.model.ms2 import pDeepModel\n", - "from peptdeep.pretrained_models import model_zip" + "from peptdeep.pretrained_models import MODEL_ZIP_FILE_PATH" ] }, { @@ -662,7 +662,7 @@ "source": [ "ms2_model = pDeepModel()\n", "ms2_model.build_from_py_codes(\n", - " model_zip, 'generic/ms2.pth.model.py', \n", + " MODEL_ZIP_FILE_PATH, 'generic/ms2.pth.model.py', \n", " include_model_params_yaml=True\n", ")\n", "\n", diff --git a/nbs_tests/pretrained_models.ipynb b/nbs_tests/pretrained_models.ipynb index ce720f1a..b7cb66ba 100644 --- a/nbs_tests/pretrained_models.ipynb +++ b/nbs_tests/pretrained_models.ipynb @@ -56,7 +56,8 @@ "outputs": [], "source": [ "#| hide\n", - "assert is_model_zip(model_zip)" + "download_models()\n", + "assert is_model_zip(MODEL_ZIP_FILE_PATH)" ] }, { @@ -98,8 +99,8 @@ "outputs": [], "source": [ "#| hide\n", - "assert os.path.isfile(model_zip)\n", - "with ZipFile(model_zip) as _zip:\n", + "assert os.path.isfile(MODEL_ZIP_FILE_PATH)\n", + "with ZipFile(MODEL_ZIP_FILE_PATH) as _zip:\n", " with _zip.open('generic/ms2.pth'):\n", " pass\n", " with _zip.open('generic/rt.pth'):\n", @@ -119,7 +120,8 @@ "outputs": [], "source": [ "#| hide\n", - "from io import StringIO" + "from io import StringIO\n", + "import torch" ] }, { diff --git a/peptdeep/cli.py b/peptdeep/cli.py index 3cec28bf..3fd0275f 100644 --- a/peptdeep/cli.py +++ b/peptdeep/cli.py @@ -85,10 +85,10 @@ def _gui(port, settings_yaml): help="If overwrite existing model file.", ) def _install_model(model_file, overwrite): - from peptdeep.pretrained_models import download_models, model_url + from peptdeep.pretrained_models import download_models, MODEL_URL if not model_file: - download_models(model_url, overwrite=overwrite) + download_models(MODEL_URL, overwrite=overwrite) else: download_models(model_file, overwrite=overwrite) diff --git a/peptdeep/hla/hla_class1.py b/peptdeep/hla/hla_class1.py index f093d53b..8cbd9d89 100644 --- a/peptdeep/hla/hla_class1.py +++ b/peptdeep/hla/hla_class1.py @@ -8,7 +8,7 @@ import peptdeep.model.building_block as building_block from peptdeep.model.model_interface import ModelInterface, append_nAA_column_if_missing from peptdeep.model.featurize import get_ascii_indices -from peptdeep.pretrained_models import pretrain_dir, download_models, global_settings +from peptdeep.pretrained_models import PRETRAIN_DIR, download_models, global_settings from .hla_utils import ( get_random_sequences, @@ -134,7 +134,7 @@ class HLA1_Binding_Classifier(ModelInterface): _model_zip_name = global_settings["local_hla_model_zip_name"] _model_url = global_settings["hla_model_url"] - _model_zip = os.path.join(pretrain_dir, _model_zip_name) + _model_zip = os.path.join(PRETRAIN_DIR, _model_zip_name) def __init__( self, diff --git a/peptdeep/pretrained_models.py b/peptdeep/pretrained_models.py index 3d6f8ef2..6962b0ba 100644 --- a/peptdeep/pretrained_models.py +++ b/peptdeep/pretrained_models.py @@ -1,18 +1,17 @@ import os -import pathlib import io import sys import pandas as pd -import torch import urllib import socket -import logging import shutil import ssl import typing from pickle import UnpicklingError import torch.multiprocessing as mp +from peptdeep.utils.deprecations import ModuleWithDeprecations + if sys.platform.lower().startswith("linux"): # to prevent `too many open files` bug on Linux mp.set_sharing_strategy("file_system") @@ -28,35 +27,36 @@ from alphabase.peptide.precursor import refine_precursor_df, update_precursor_mz from alphabase.peptide.mobility import mobility_to_ccs_for_df, ccs_to_mobility_for_df -from peptdeep.settings import global_settings, add_user_defined_modifications from peptdeep.utils import logging, process_bar -from peptdeep.settings import global_settings from peptdeep.model.ms2 import ( pDeepModel, normalize_fragment_intensities, - calc_ms2_similarity, ) from peptdeep.model.rt import AlphaRTModel from peptdeep.model.ccs import AlphaCCSModel from peptdeep.model.charge import ChargeModelForAASeq, ChargeModelForModAASeq -from peptdeep.utils import uniform_sampling, evaluate_linear_regression +from peptdeep.utils import uniform_sampling from peptdeep.settings import global_settings, update_global_settings -pretrain_dir = os.path.join( +PRETRAIN_DIR = os.path.join( os.path.join( os.path.expanduser(global_settings["PEPTDEEP_HOME"]), "pretrained_models" ) ) -if not os.path.exists(pretrain_dir): - os.makedirs(pretrain_dir) -model_zip_name = global_settings["local_model_zip_name"] -model_url = global_settings["model_url"] +sys.modules[__name__].__class__ = ModuleWithDeprecations + +LOCAL_MODAL_ZIP_NAME = global_settings["local_model_zip_name"] +MODEL_URL = global_settings["model_url"] +MODEL_ZIP_FILE_PATH = os.path.join(PRETRAIN_DIR, LOCAL_MODAL_ZIP_NAME) -model_zip = os.path.join(pretrain_dir, model_zip_name) +ModuleWithDeprecations.deprecate(__name__, "pretrain_dir", "PRETRAIN_DIR") +ModuleWithDeprecations.deprecate(__name__, "model_zip_name", "LOCAL_MODAL_ZIP_NAME") +ModuleWithDeprecations.deprecate(__name__, "model_url", "MODEL_URL") +ModuleWithDeprecations.deprecate(__name__, "model_zip", "MODEL_ZIP_FILE_PATH") def is_model_zip(downloaded_zip): @@ -64,7 +64,7 @@ def is_model_zip(downloaded_zip): return any(x == "generic/ms2.pth" for x in zip.namelist()) -def download_models(url: str = model_url, target_path: str = model_zip, overwrite=True): +def download_models(url: str = MODEL_URL, target_path: str = MODEL_ZIP_FILE_PATH): """ Parameters ---------- @@ -74,7 +74,7 @@ def download_models(url: str = model_url, target_path: str = model_zip, overwrit target_path : str, optional Target file path after download. - Defaults to :data:`peptdeep.pretrained_models.model_zip` + Defaults to :data:`peptdeep.pretrained_models.MODEL_ZIP_FILE_PATH` overwrite : bool, optional overwirte old model files. @@ -97,7 +97,7 @@ def download_models(url: str = model_url, target_path: str = model_zip, overwrit "Downloading model failed! Please download the " f'zip or tar file by yourself from "{url}",' " and use \n" - f'"peptdeep --install-model /path/to/{model_zip_name}.zip"\n' + f'"peptdeep --install-model /path/to/{LOCAL_MODAL_ZIP_NAME}.zip"\n' " to install the models" ) else: @@ -105,8 +105,14 @@ def download_models(url: str = model_url, target_path: str = model_zip, overwrit logging.info(f"The pretrained models had been downloaded in {target_path}") -if not os.path.exists(model_zip): - download_models() +def _download_models(model_zip_file_path: str) -> None: + """Download models if not done yet.""" + os.makedirs(PRETRAIN_DIR, exist_ok=True) + if not os.path.exists(model_zip_file_path): + download_models() + if not is_model_zip(model_zip_file_path): + raise ValueError(f"Local model file is not a valid zip: {model_zip_file_path}") + model_mgr_settings = global_settings["model_mgr"] @@ -185,32 +191,39 @@ def _sample(psm_df, n): def load_phos_models(mask_modloss=True): + _download_models(MODEL_ZIP_FILE_PATH) ms2_model = pDeepModel(mask_modloss=mask_modloss) - ms2_model.load(model_zip, model_path_in_zip="phospho/ms2_phos.pth") + ms2_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip="phospho/ms2_phos.pth") rt_model = AlphaRTModel() - rt_model.load(model_zip, model_path_in_zip="phospho/rt_phos.pth") + rt_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip="phospho/rt_phos.pth") ccs_model = AlphaCCSModel() - ccs_model.load(model_zip, model_path_in_zip="generic/ccs.pth") + ccs_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ccs.pth") return ms2_model, rt_model, ccs_model def load_models(mask_modloss=True): + _download_models(MODEL_ZIP_FILE_PATH) ms2_model = pDeepModel(mask_modloss=mask_modloss) - ms2_model.load(model_zip, model_path_in_zip="generic/ms2.pth") + ms2_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ms2.pth") rt_model = AlphaRTModel() - rt_model.load(model_zip, model_path_in_zip="generic/rt.pth") + rt_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/rt.pth") ccs_model = AlphaCCSModel() - ccs_model.load(model_zip, model_path_in_zip="generic/ccs.pth") + ccs_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ccs.pth") return ms2_model, rt_model, ccs_model def load_models_by_model_type_in_zip(model_type_in_zip: str, mask_modloss=True): + _download_models(MODEL_ZIP_FILE_PATH) ms2_model = pDeepModel(mask_modloss=mask_modloss) - ms2_model.load(model_zip, model_path_in_zip=f"{model_type_in_zip}/ms2.pth") + ms2_model.load( + MODEL_ZIP_FILE_PATH, model_path_in_zip=f"{model_type_in_zip}/ms2.pth" + ) rt_model = AlphaRTModel() - rt_model.load(model_zip, model_path_in_zip=f"{model_type_in_zip}/rt.pth") + rt_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip=f"{model_type_in_zip}/rt.pth") ccs_model = AlphaCCSModel() - ccs_model.load(model_zip, model_path_in_zip=f"{model_type_in_zip}/ccs.pth") + ccs_model.load( + MODEL_ZIP_FILE_PATH, model_path_in_zip=f"{model_type_in_zip}/ccs.pth" + ) return ms2_model, rt_model, ccs_model @@ -284,6 +297,8 @@ def __init__( if device=='gpu' but no GPUs are detected, it will automatically switch to 'cpu'. Defaults to 'gpu' """ + _download_models(MODEL_ZIP_FILE_PATH) + self._train_psm_logging = True self.ms2_model: pDeepModel = pDeepModel( @@ -430,9 +445,15 @@ def load_installed_models(self, model_type: str = "generic"): Defaults to 'generic'. """ if model_type.lower() in ["phospho", "phos", "phosphorylation"]: - self.ms2_model.load(model_zip, model_path_in_zip="generic/ms2.pth") - self.rt_model.load(model_zip, model_path_in_zip="phospho/rt_phos.pth") - self.ccs_model.load(model_zip, model_path_in_zip="generic/ccs.pth") + self.ms2_model.load( + MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ms2.pth" + ) + self.rt_model.load( + MODEL_ZIP_FILE_PATH, model_path_in_zip="phospho/rt_phos.pth" + ) + self.ccs_model.load( + MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ccs.pth" + ) elif model_type.lower() in [ "digly", "glygly", @@ -440,13 +461,23 @@ def load_installed_models(self, model_type: str = "generic"): "ubiquitination", "ubiquitinylation", ]: - self.ms2_model.load(model_zip, model_path_in_zip="generic/ms2.pth") - self.rt_model.load(model_zip, model_path_in_zip="digly/rt_digly.pth") - self.ccs_model.load(model_zip, model_path_in_zip="generic/ccs.pth") + self.ms2_model.load( + MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ms2.pth" + ) + self.rt_model.load( + MODEL_ZIP_FILE_PATH, model_path_in_zip="digly/rt_digly.pth" + ) + self.ccs_model.load( + MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ccs.pth" + ) elif model_type.lower() in ["regular", "common", "generic"]: - self.ms2_model.load(model_zip, model_path_in_zip="generic/ms2.pth") - self.rt_model.load(model_zip, model_path_in_zip="generic/rt.pth") - self.ccs_model.load(model_zip, model_path_in_zip="generic/ccs.pth") + self.ms2_model.load( + MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ms2.pth" + ) + self.rt_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/rt.pth") + self.ccs_model.load( + MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ccs.pth" + ) elif model_type.lower() in ["hla", "unspecific", "non-specific", "nonspecific"]: self.load_installed_models(model_type="generic") else: diff --git a/peptdeep/utils/deprecations.py b/peptdeep/utils/deprecations.py new file mode 100644 index 00000000..f8ba3927 --- /dev/null +++ b/peptdeep/utils/deprecations.py @@ -0,0 +1,39 @@ +"""ModuleType to deprecate variables in a module.""" + +from collections import defaultdict +from typing import Any +from warnings import warn +from types import ModuleType + + +class ModuleWithDeprecations(ModuleType): + """ModuleType to deprecate variables in a module.""" + + _deprecations = defaultdict(dict) + + def __getattr__(self, name: str) -> Any: + """Get unknown module attributes, raising a warning if it's deprecated. + + To deprecate a variable: + > import sys, ModuleWithDeprecations + > sys.modules[__name__].__class__ = ModuleWithDeprecations + > ModuleWithDeprecations.deprecate(__name__, "old_name", "new_name") + """ + module_deprecations = self._deprecations[self.__name__] + if name in module_deprecations: + new_name = module_deprecations[name] + msg = f"{name} is deprecated! Use '{new_name}' instead." + warn(msg, DeprecationWarning) + print(f"WARNING: {msg}") + return self.__getattribute__(new_name) + + # to get the standard error message + return object().__getattribute__(name) + + @classmethod + def deprecate(cls, class_name: str, old_name: str, new_name: str) -> None: + """Deprecate `old_name` in favour of `new_name` for `class_name`. + + Pass "__name__" as first argument. + """ + cls._deprecations[class_name][old_name] = new_name