Lazy load pretrained models #229
Changed file (presumably `peptdeep/pretrained_models.py`, given the docstring references in the diff):

```diff
@@ -1,18 +1,17 @@
 import os
 import pathlib
 import io
 import sys
 import pandas as pd
 import torch
 import urllib
 import socket
 import logging
 import shutil
 import ssl
 import typing
 from pickle import UnpicklingError
 import torch.multiprocessing as mp

+from peptdeep.utils.deprecations import ModuleWithDeprecations

 if sys.platform.lower().startswith("linux"):
     # to prevent `too many open files` bug on Linux
     mp.set_sharing_strategy("file_system")
@@ -28,43 +27,44 @@
 from alphabase.peptide.precursor import refine_precursor_df, update_precursor_mz
 from alphabase.peptide.mobility import mobility_to_ccs_for_df, ccs_to_mobility_for_df

 from peptdeep.settings import global_settings, add_user_defined_modifications
 from peptdeep.utils import logging, process_bar
 from peptdeep.settings import global_settings

 from peptdeep.model.ms2 import (
     pDeepModel,
     normalize_fragment_intensities,
     calc_ms2_similarity,
 )
 from peptdeep.model.rt import AlphaRTModel
 from peptdeep.model.ccs import AlphaCCSModel
 from peptdeep.model.charge import ChargeModelForAASeq, ChargeModelForModAASeq
-from peptdeep.utils import uniform_sampling, evaluate_linear_regression
+from peptdeep.utils import uniform_sampling

 from peptdeep.settings import global_settings, update_global_settings

-pretrain_dir = os.path.join(
+PRETRAIN_DIR = os.path.join(
     os.path.join(
         os.path.expanduser(global_settings["PEPTDEEP_HOME"]), "pretrained_models"
     )
 )

-if not os.path.exists(pretrain_dir):
-    os.makedirs(pretrain_dir)

-model_zip_name = global_settings["local_model_zip_name"]
-model_url = global_settings["model_url"]
+sys.modules[__name__].__class__ = ModuleWithDeprecations

+LOCAL_MODAL_ZIP_NAME = global_settings["local_model_zip_name"]
+MODEL_URL = global_settings["model_url"]
+MODEL_ZIP_FILE_PATH = os.path.join(PRETRAIN_DIR, LOCAL_MODAL_ZIP_NAME)

-model_zip = os.path.join(pretrain_dir, model_zip_name)
+ModuleWithDeprecations.deprecate(__name__, "pretrain_dir", "PRETRAIN_DIR")
+ModuleWithDeprecations.deprecate(__name__, "model_zip_name", "LOCAL_MODAL_ZIP_NAME")
+ModuleWithDeprecations.deprecate(__name__, "model_url", "MODEL_URL")
+ModuleWithDeprecations.deprecate(__name__, "model_zip", "MODEL_ZIP_FILE_PATH")


 def is_model_zip(downloaded_zip):
     with ZipFile(downloaded_zip) as zip:
         return any(x == "generic/ms2.pth" for x in zip.namelist())


-def download_models(url: str = model_url, target_path: str = model_zip, overwrite=True):
+def download_models(url: str = MODEL_URL, target_path: str = MODEL_ZIP_FILE_PATH):
     """
     Parameters
     ----------
@@ -74,7 +74,7 @@ def download_models(url: str = model_url, target_path: str = model_zip, overwrit
     target_path : str, optional
         Target file path after download.
-        Defaults to :data:`peptdeep.pretrained_models.model_zip`
+        Defaults to :data:`peptdeep.pretrained_models.MODEL_ZIP_FILE_PATH`

     overwrite : bool, optional
         overwirte old model files.
@@ -97,16 +97,22 @@ def download_models(url: str = model_url, target_path: str = model_zip, overwrit
             "Downloading model failed! Please download the "
             f'zip or tar file by yourself from "{url}",'
             " and use \n"
-            f'"peptdeep --install-model /path/to/{model_zip_name}.zip"\n'
+            f'"peptdeep --install-model /path/to/{LOCAL_MODAL_ZIP_NAME}.zip"\n'
             " to install the models"
         )
     else:
         shutil.copy(url, target_path)
     logging.info(f"The pretrained models had been downloaded in {target_path}")


-if not os.path.exists(model_zip):
-    download_models()
```
Comment on lines -108 to -109 (Contributor, Author): this is where the downloading at import time was done ..
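To illustrate the behavioral change being discussed here, a minimal sketch (not code from the PR; it only assumes the public names shown in this diff):

```python
# Before this PR, importing the module could trigger a network download as a
# side effect. After this PR, the import is cheap and the check/download runs
# lazily, the first time a loader actually needs the model zip.
from peptdeep import pretrained_models  # no download happens here anymore

# The first call that needs the weights triggers _download_models() internally:
ms2_model, rt_model, ccs_model = pretrained_models.load_models(mask_modloss=True)
```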
```diff
+def _download_models(model_zip_file_path: str) -> None:
+    """Download models if not done yet."""
+    os.makedirs(PRETRAIN_DIR, exist_ok=True)
+    if not os.path.exists(model_zip_file_path):
+        download_models()
+    if not is_model_zip(model_zip_file_path):
+        raise ValueError(f"Local model file is not a valid zip: {model_zip_file_path}")


 model_mgr_settings = global_settings["model_mgr"]
@@ -185,32 +191,39 @@ def _sample(psm_df, n):


 def load_phos_models(mask_modloss=True):
+    _download_models(MODEL_ZIP_FILE_PATH)
     ms2_model = pDeepModel(mask_modloss=mask_modloss)
-    ms2_model.load(model_zip, model_path_in_zip="phospho/ms2_phos.pth")
+    ms2_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip="phospho/ms2_phos.pth")
     rt_model = AlphaRTModel()
-    rt_model.load(model_zip, model_path_in_zip="phospho/rt_phos.pth")
+    rt_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip="phospho/rt_phos.pth")
     ccs_model = AlphaCCSModel()
-    ccs_model.load(model_zip, model_path_in_zip="generic/ccs.pth")
+    ccs_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ccs.pth")
     return ms2_model, rt_model, ccs_model


 def load_models(mask_modloss=True):
+    _download_models(MODEL_ZIP_FILE_PATH)
     ms2_model = pDeepModel(mask_modloss=mask_modloss)
-    ms2_model.load(model_zip, model_path_in_zip="generic/ms2.pth")
+    ms2_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ms2.pth")
     rt_model = AlphaRTModel()
-    rt_model.load(model_zip, model_path_in_zip="generic/rt.pth")
+    rt_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/rt.pth")
     ccs_model = AlphaCCSModel()
-    ccs_model.load(model_zip, model_path_in_zip="generic/ccs.pth")
+    ccs_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ccs.pth")
     return ms2_model, rt_model, ccs_model


 def load_models_by_model_type_in_zip(model_type_in_zip: str, mask_modloss=True):
+    _download_models(MODEL_ZIP_FILE_PATH)
     ms2_model = pDeepModel(mask_modloss=mask_modloss)
-    ms2_model.load(model_zip, model_path_in_zip=f"{model_type_in_zip}/ms2.pth")
+    ms2_model.load(
+        MODEL_ZIP_FILE_PATH, model_path_in_zip=f"{model_type_in_zip}/ms2.pth"
+    )
     rt_model = AlphaRTModel()
-    rt_model.load(model_zip, model_path_in_zip=f"{model_type_in_zip}/rt.pth")
+    rt_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip=f"{model_type_in_zip}/rt.pth")
     ccs_model = AlphaCCSModel()
-    ccs_model.load(model_zip, model_path_in_zip=f"{model_type_in_zip}/ccs.pth")
+    ccs_model.load(
+        MODEL_ZIP_FILE_PATH, model_path_in_zip=f"{model_type_in_zip}/ccs.pth"
+    )
     return ms2_model, rt_model, ccs_model
@@ -284,6 +297,8 @@ def __init__(
         if device=='gpu' but no GPUs are detected, it will automatically switch to 'cpu'.
         Defaults to 'gpu'
         """
+        _download_models(MODEL_ZIP_FILE_PATH)

         self._train_psm_logging = True

         self.ms2_model: pDeepModel = pDeepModel(
@@ -430,23 +445,39 @@ def load_installed_models(self, model_type: str = "generic"):
             Defaults to 'generic'.
         """
         if model_type.lower() in ["phospho", "phos", "phosphorylation"]:
-            self.ms2_model.load(model_zip, model_path_in_zip="generic/ms2.pth")
-            self.rt_model.load(model_zip, model_path_in_zip="phospho/rt_phos.pth")
-            self.ccs_model.load(model_zip, model_path_in_zip="generic/ccs.pth")
+            self.ms2_model.load(
+                MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ms2.pth"
+            )
+            self.rt_model.load(
+                MODEL_ZIP_FILE_PATH, model_path_in_zip="phospho/rt_phos.pth"
+            )
+            self.ccs_model.load(
+                MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ccs.pth"
+            )
         elif model_type.lower() in [
             "digly",
             "glygly",
             "ubiquitylation",
             "ubiquitination",
             "ubiquitinylation",
         ]:
-            self.ms2_model.load(model_zip, model_path_in_zip="generic/ms2.pth")
-            self.rt_model.load(model_zip, model_path_in_zip="digly/rt_digly.pth")
-            self.ccs_model.load(model_zip, model_path_in_zip="generic/ccs.pth")
+            self.ms2_model.load(
+                MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ms2.pth"
+            )
+            self.rt_model.load(
+                MODEL_ZIP_FILE_PATH, model_path_in_zip="digly/rt_digly.pth"
+            )
+            self.ccs_model.load(
+                MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ccs.pth"
+            )
         elif model_type.lower() in ["regular", "common", "generic"]:
-            self.ms2_model.load(model_zip, model_path_in_zip="generic/ms2.pth")
-            self.rt_model.load(model_zip, model_path_in_zip="generic/rt.pth")
-            self.ccs_model.load(model_zip, model_path_in_zip="generic/ccs.pth")
+            self.ms2_model.load(
+                MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ms2.pth"
+            )
+            self.rt_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/rt.pth")
+            self.ccs_model.load(
+                MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ccs.pth"
+            )
         elif model_type.lower() in ["hla", "unspecific", "non-specific", "nonspecific"]:
             self.load_installed_models(model_type="generic")
         else:
```
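As a side note, the `is_model_zip` check that `_download_models` now enforces can be exercised in isolation. A hypothetical sanity check (not part of the PR):

```python
import os
import tempfile
import zipfile

from peptdeep.pretrained_models import is_model_zip

# A zip that does not contain "generic/ms2.pth" should not be accepted
# as a pretrained-model archive.
with tempfile.TemporaryDirectory() as tmp_dir:
    bad_zip_path = os.path.join(tmp_dir, "not_models.zip")
    with zipfile.ZipFile(bad_zip_path, "w") as zf:
        zf.writestr("readme.txt", "not a model archive")
    assert not is_model_zip(bad_zip_path)
```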
New file (presumably `peptdeep/utils/deprecations.py`, matching the import added above):

```diff
@@ -0,0 +1,39 @@
+"""ModuleType to deprecate variables in a module."""
+
+from collections import defaultdict
+from typing import Any
+from warnings import warn
+from types import ModuleType
+
+
+class ModuleWithDeprecations(ModuleType):
+    """ModuleType to deprecate variables in a module."""
+
+    _deprecations = defaultdict(dict)
+
+    def __getattr__(self, name: str) -> Any:
+        """Get unknown module attributes, raising a warning if it's deprecated.
+
+        To deprecate a variable:
+        > import sys, ModuleWithDeprecations
+        > sys.modules[__name__].__class__ = ModuleWithDeprecations
```
Collaborator: This example confuses me a lot. Does it need to be `sys`, or is this just an example?

Contributor (Author): this is literally what needs to be done: https://github.com/MannLabs/alphapeptdeep/pull/229/files#diff-c01c08ac727782806e8ddbb06d6e9b41f98c7505b81de9366a03581826c264a6R50

Contributor (Author): maybe this deprecation mechanism is also too much?
```diff
+        > ModuleWithDeprecations.deprecate(__name__, "old_name", "new_name")
+        """
+        module_deprecations = self._deprecations[self.__name__]
+        if name in module_deprecations:
+            new_name = module_deprecations[name]
+            msg = f"{name} is deprecated! Use '{new_name}' instead."
+            warn(msg, DeprecationWarning)
+            print(f"WARNING: {msg}")
+            return self.__getattribute__(new_name)
+
+        # to get the standard error message
+        return object().__getattribute__(name)
+
+    @classmethod
+    def deprecate(cls, class_name: str, old_name: str, new_name: str) -> None:
+        """Deprecate `old_name` in favour of `new_name` for `class_name`.
+
+        Pass "__name__" as first argument.
+        """
+        cls._deprecations[class_name][old_name] = new_name
```
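Regarding the confusion in the thread above about the `sys.modules[__name__].__class__` line: here is a minimal, self-contained sketch of the same pattern outside peptdeep. The attribute names `old_value` and `NEW_VALUE` are invented for illustration.

```python
"""Standalone demo of the module-level deprecation pattern used in this PR."""

import sys
from collections import defaultdict
from types import ModuleType
from typing import Any
from warnings import warn


class ModuleWithDeprecations(ModuleType):
    _deprecations = defaultdict(dict)

    def __getattr__(self, name: str) -> Any:
        # Called only when normal module attribute lookup fails.
        deprecated = self._deprecations[self.__name__]
        if name in deprecated:
            new_name = deprecated[name]
            warn(f"{name} is deprecated! Use '{new_name}' instead.", DeprecationWarning)
            return self.__getattribute__(new_name)
        raise AttributeError(name)

    @classmethod
    def deprecate(cls, module_name: str, old_name: str, new_name: str) -> None:
        cls._deprecations[module_name][old_name] = new_name


NEW_VALUE = 42  # the renamed module-level variable

# This is why `sys` is needed: the *module object itself* must be given the new
# class, so that attribute access on the module goes through __getattr__ above.
sys.modules[__name__].__class__ = ModuleWithDeprecations
ModuleWithDeprecations.deprecate(__name__, "old_value", "NEW_VALUE")

if __name__ == "__main__":
    this_module = sys.modules[__name__]
    print(this_module.old_value)  # emits a DeprecationWarning, then prints 42
```

Run as a script, accessing `old_value` on the module warns about the deprecation and transparently returns `NEW_VALUE`.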
Comment: Should this be an all-caps constant if it's actually loaded from a config? Why not use `global_settings["model_url"]` directly?

Reply (Author): well, it's a string and effectively constant (within the lifetime of a peptdeep instance) .. that's why I chose ALL_CAPS
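For the trade-off being discussed, a sketch of the two options (neither snippet is claimed to be the PR's final code; `download_from_constant` and `download_from_config` are hypothetical helpers):

```python
from typing import Optional

from peptdeep.settings import global_settings

# Option taken in the PR: snapshot the config value once, at import time.
MODEL_URL = global_settings["model_url"]


def download_from_constant(url: str = MODEL_URL) -> None:
    # Uses whatever the setting was when the module was first imported.
    print(f"downloading from {url}")


# Reviewer's alternative: read the config at call time, so later changes to
# global_settings["model_url"] are picked up automatically.
def download_from_config(url: Optional[str] = None) -> None:
    effective_url = url if url is not None else global_settings["model_url"]
    print(f"downloading from {effective_url}")
```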