Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions nbs_tests/model/model_interface.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@
"outputs": [],
"source": [
"from peptdeep.model.ms2 import pDeepModel\n",
"from peptdeep.pretrained_models import model_zip"
"from peptdeep.pretrained_models import MODEL_ZIP_FILE_PATH"
]
},
{
Expand Down Expand Up @@ -662,7 +662,7 @@
"source": [
"ms2_model = pDeepModel()\n",
"ms2_model.build_from_py_codes(\n",
" model_zip, 'generic/ms2.pth.model.py', \n",
" MODEL_ZIP_FILE_PATH, 'generic/ms2.pth.model.py', \n",
" include_model_params_yaml=True\n",
")\n",
"\n",
Expand Down
10 changes: 6 additions & 4 deletions nbs_tests/pretrained_models.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@
"outputs": [],
"source": [
"#| hide\n",
"assert is_model_zip(model_zip)"
"download_models()\n",
"assert is_model_zip(MODEL_ZIP_FILE_PATH)"
]
},
{
Expand Down Expand Up @@ -98,8 +99,8 @@
"outputs": [],
"source": [
"#| hide\n",
"assert os.path.isfile(model_zip)\n",
"with ZipFile(model_zip) as _zip:\n",
"assert os.path.isfile(MODEL_ZIP_FILE_PATH)\n",
"with ZipFile(MODEL_ZIP_FILE_PATH) as _zip:\n",
" with _zip.open('generic/ms2.pth'):\n",
" pass\n",
" with _zip.open('generic/rt.pth'):\n",
Expand All @@ -119,7 +120,8 @@
"outputs": [],
"source": [
"#| hide\n",
"from io import StringIO"
"from io import StringIO\n",
"import torch"
]
},
{
Expand Down
4 changes: 2 additions & 2 deletions peptdeep/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,10 @@ def _gui(port, settings_yaml):
help="If overwrite existing model file.",
)
def _install_model(model_file, overwrite):
from peptdeep.pretrained_models import download_models, model_url
from peptdeep.pretrained_models import download_models, MODEL_URL

if not model_file:
download_models(model_url, overwrite=overwrite)
download_models(MODEL_URL, overwrite=overwrite)
else:
download_models(model_file, overwrite=overwrite)

Expand Down
4 changes: 2 additions & 2 deletions peptdeep/hla/hla_class1.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import peptdeep.model.building_block as building_block
from peptdeep.model.model_interface import ModelInterface, append_nAA_column_if_missing
from peptdeep.model.featurize import get_ascii_indices
from peptdeep.pretrained_models import pretrain_dir, download_models, global_settings
from peptdeep.pretrained_models import PRETRAIN_DIR, download_models, global_settings

from .hla_utils import (
get_random_sequences,
Expand Down Expand Up @@ -134,7 +134,7 @@ class HLA1_Binding_Classifier(ModelInterface):

_model_zip_name = global_settings["local_hla_model_zip_name"]
_model_url = global_settings["hla_model_url"]
_model_zip = os.path.join(pretrain_dir, _model_zip_name)
_model_zip = os.path.join(PRETRAIN_DIR, _model_zip_name)

def __init__(
self,
Expand Down
103 changes: 67 additions & 36 deletions peptdeep/pretrained_models.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,17 @@
import os
import pathlib
import io
import sys
import pandas as pd
import torch
import urllib
import socket
import logging
import shutil
import ssl
import typing
from pickle import UnpicklingError
import torch.multiprocessing as mp

from peptdeep.utils.deprecations import ModuleWithDeprecations

if sys.platform.lower().startswith("linux"):
# to prevent `too many open files` bug on Linux
mp.set_sharing_strategy("file_system")
Expand All @@ -28,43 +27,44 @@
from alphabase.peptide.precursor import refine_precursor_df, update_precursor_mz
from alphabase.peptide.mobility import mobility_to_ccs_for_df, ccs_to_mobility_for_df

from peptdeep.settings import global_settings, add_user_defined_modifications
from peptdeep.utils import logging, process_bar
from peptdeep.settings import global_settings

from peptdeep.model.ms2 import (
pDeepModel,
normalize_fragment_intensities,
calc_ms2_similarity,
)
from peptdeep.model.rt import AlphaRTModel
from peptdeep.model.ccs import AlphaCCSModel
from peptdeep.model.charge import ChargeModelForAASeq, ChargeModelForModAASeq
from peptdeep.utils import uniform_sampling, evaluate_linear_regression
from peptdeep.utils import uniform_sampling

from peptdeep.settings import global_settings, update_global_settings

pretrain_dir = os.path.join(
PRETRAIN_DIR = os.path.join(
os.path.join(
os.path.expanduser(global_settings["PEPTDEEP_HOME"]), "pretrained_models"
)
)

if not os.path.exists(pretrain_dir):
os.makedirs(pretrain_dir)

model_zip_name = global_settings["local_model_zip_name"]
model_url = global_settings["model_url"]
sys.modules[__name__].__class__ = ModuleWithDeprecations

LOCAL_MODAL_ZIP_NAME = global_settings["local_model_zip_name"]
MODEL_URL = global_settings["model_url"]
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be an all caps constant if it's actually loaded from a config?
Why not using global_settings["model_url"] directly?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

well, it's a string and somehow constant (within the lifetime of a peptdeep instance) .. that's why I chose ALL_CAPS

MODEL_ZIP_FILE_PATH = os.path.join(PRETRAIN_DIR, LOCAL_MODAL_ZIP_NAME)

model_zip = os.path.join(pretrain_dir, model_zip_name)
ModuleWithDeprecations.deprecate(__name__, "pretrain_dir", "PRETRAIN_DIR")
ModuleWithDeprecations.deprecate(__name__, "model_zip_name", "LOCAL_MODAL_ZIP_NAME")
ModuleWithDeprecations.deprecate(__name__, "model_url", "MODEL_URL")
ModuleWithDeprecations.deprecate(__name__, "model_zip", "MODEL_ZIP_FILE_PATH")


def is_model_zip(downloaded_zip) -> bool:
    """Check whether `downloaded_zip` is a valid peptdeep model zip.

    A valid model zip must contain the archive entry ``generic/ms2.pth``.

    Parameters
    ----------
    downloaded_zip : str
        Path to the zip file to inspect.

    Returns
    -------
    bool
        True if ``generic/ms2.pth`` is present in the archive.
    """
    # Membership test instead of `any(x == ... for x in ...)`, and a named
    # handle instead of shadowing the builtin `zip`.
    with ZipFile(downloaded_zip) as model_zip_file:
        return "generic/ms2.pth" in model_zip_file.namelist()


def download_models(url: str = model_url, target_path: str = model_zip, overwrite=True):
def download_models(url: str = MODEL_URL, target_path: str = MODEL_ZIP_FILE_PATH):
"""
Parameters
----------
Expand All @@ -74,7 +74,7 @@ def download_models(url: str = model_url, target_path: str = model_zip, overwrit

target_path : str, optional
Target file path after download.
Defaults to :data:`peptdeep.pretrained_models.model_zip`
Defaults to :data:`peptdeep.pretrained_models.MODEL_ZIP_FILE_PATH`

overwrite : bool, optional
overwrite old model files.
Expand All @@ -97,16 +97,22 @@ def download_models(url: str = model_url, target_path: str = model_zip, overwrit
"Downloading model failed! Please download the "
f'zip or tar file by yourself from "{url}",'
" and use \n"
f'"peptdeep --install-model /path/to/{model_zip_name}.zip"\n'
f'"peptdeep --install-model /path/to/{LOCAL_MODAL_ZIP_NAME}.zip"\n'
" to install the models"
)
else:
shutil.copy(url, target_path)
logging.info(f"The pretrained models had been downloaded in {target_path}")


if not os.path.exists(model_zip):
download_models()
Comment on lines -108 to -109
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

def _download_models(model_zip_file_path: str) -> None:
    """Ensure the pretrained model zip exists locally, downloading if needed.

    Parameters
    ----------
    model_zip_file_path : str
        Path where the pretrained-model zip file is expected.

    Raises
    ------
    ValueError
        If the file at `model_zip_file_path` is not a valid model zip.
    """
    os.makedirs(PRETRAIN_DIR, exist_ok=True)
    if not os.path.exists(model_zip_file_path):
        # Pass the requested path explicitly: calling `download_models()` with
        # no arguments downloads to the default MODEL_ZIP_FILE_PATH, so a
        # non-default `model_zip_file_path` would never be created and the
        # validation below would always fail.
        download_models(target_path=model_zip_file_path)
    if not is_model_zip(model_zip_file_path):
        raise ValueError(f"Local model file is not a valid zip: {model_zip_file_path}")


model_mgr_settings = global_settings["model_mgr"]

Expand Down Expand Up @@ -185,32 +191,39 @@ def _sample(psm_df, n):


def load_phos_models(mask_modloss=True):
_download_models(MODEL_ZIP_FILE_PATH)
ms2_model = pDeepModel(mask_modloss=mask_modloss)
ms2_model.load(model_zip, model_path_in_zip="phospho/ms2_phos.pth")
ms2_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip="phospho/ms2_phos.pth")
rt_model = AlphaRTModel()
rt_model.load(model_zip, model_path_in_zip="phospho/rt_phos.pth")
rt_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip="phospho/rt_phos.pth")
ccs_model = AlphaCCSModel()
ccs_model.load(model_zip, model_path_in_zip="generic/ccs.pth")
ccs_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ccs.pth")
return ms2_model, rt_model, ccs_model


def load_models(mask_modloss=True):
_download_models(MODEL_ZIP_FILE_PATH)
ms2_model = pDeepModel(mask_modloss=mask_modloss)
ms2_model.load(model_zip, model_path_in_zip="generic/ms2.pth")
ms2_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ms2.pth")
rt_model = AlphaRTModel()
rt_model.load(model_zip, model_path_in_zip="generic/rt.pth")
rt_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/rt.pth")
ccs_model = AlphaCCSModel()
ccs_model.load(model_zip, model_path_in_zip="generic/ccs.pth")
ccs_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ccs.pth")
return ms2_model, rt_model, ccs_model


def load_models_by_model_type_in_zip(model_type_in_zip: str, mask_modloss=True):
_download_models(MODEL_ZIP_FILE_PATH)
ms2_model = pDeepModel(mask_modloss=mask_modloss)
ms2_model.load(model_zip, model_path_in_zip=f"{model_type_in_zip}/ms2.pth")
ms2_model.load(
MODEL_ZIP_FILE_PATH, model_path_in_zip=f"{model_type_in_zip}/ms2.pth"
)
rt_model = AlphaRTModel()
rt_model.load(model_zip, model_path_in_zip=f"{model_type_in_zip}/rt.pth")
rt_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip=f"{model_type_in_zip}/rt.pth")
ccs_model = AlphaCCSModel()
ccs_model.load(model_zip, model_path_in_zip=f"{model_type_in_zip}/ccs.pth")
ccs_model.load(
MODEL_ZIP_FILE_PATH, model_path_in_zip=f"{model_type_in_zip}/ccs.pth"
)
return ms2_model, rt_model, ccs_model


Expand Down Expand Up @@ -284,6 +297,8 @@ def __init__(
if device=='gpu' but no GPUs are detected, it will automatically switch to 'cpu'.
Defaults to 'gpu'
"""
_download_models(MODEL_ZIP_FILE_PATH)

self._train_psm_logging = True

self.ms2_model: pDeepModel = pDeepModel(
Expand Down Expand Up @@ -430,23 +445,39 @@ def load_installed_models(self, model_type: str = "generic"):
Defaults to 'generic'.
"""
if model_type.lower() in ["phospho", "phos", "phosphorylation"]:
self.ms2_model.load(model_zip, model_path_in_zip="generic/ms2.pth")
self.rt_model.load(model_zip, model_path_in_zip="phospho/rt_phos.pth")
self.ccs_model.load(model_zip, model_path_in_zip="generic/ccs.pth")
self.ms2_model.load(
MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ms2.pth"
)
self.rt_model.load(
MODEL_ZIP_FILE_PATH, model_path_in_zip="phospho/rt_phos.pth"
)
self.ccs_model.load(
MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ccs.pth"
)
elif model_type.lower() in [
"digly",
"glygly",
"ubiquitylation",
"ubiquitination",
"ubiquitinylation",
]:
self.ms2_model.load(model_zip, model_path_in_zip="generic/ms2.pth")
self.rt_model.load(model_zip, model_path_in_zip="digly/rt_digly.pth")
self.ccs_model.load(model_zip, model_path_in_zip="generic/ccs.pth")
self.ms2_model.load(
MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ms2.pth"
)
self.rt_model.load(
MODEL_ZIP_FILE_PATH, model_path_in_zip="digly/rt_digly.pth"
)
self.ccs_model.load(
MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ccs.pth"
)
elif model_type.lower() in ["regular", "common", "generic"]:
self.ms2_model.load(model_zip, model_path_in_zip="generic/ms2.pth")
self.rt_model.load(model_zip, model_path_in_zip="generic/rt.pth")
self.ccs_model.load(model_zip, model_path_in_zip="generic/ccs.pth")
self.ms2_model.load(
MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ms2.pth"
)
self.rt_model.load(MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/rt.pth")
self.ccs_model.load(
MODEL_ZIP_FILE_PATH, model_path_in_zip="generic/ccs.pth"
)
elif model_type.lower() in ["hla", "unspecific", "non-specific", "nonspecific"]:
self.load_installed_models(model_type="generic")
else:
Expand Down
39 changes: 39 additions & 0 deletions peptdeep/utils/deprecations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
"""ModuleType to deprecate variables in a module."""

from collections import defaultdict
from typing import Any
from warnings import warn
from types import ModuleType


class ModuleWithDeprecations(ModuleType):
    """ModuleType to deprecate variables in a module."""

    # Maps module name -> {deprecated_name: replacement_name}.
    _deprecations = defaultdict(dict)

    def __getattr__(self, name: str) -> Any:
        """Get unknown module attributes, raising a warning if deprecated.

        To deprecate a variable in a module:
        > import sys
        > from peptdeep.utils.deprecations import ModuleWithDeprecations
        > sys.modules[__name__].__class__ = ModuleWithDeprecations
        > ModuleWithDeprecations.deprecate(__name__, "old_name", "new_name")
        """
        module_deprecations = self._deprecations[self.__name__]
        if name in module_deprecations:
            new_name = module_deprecations[name]
            msg = f"{name} is deprecated! Use '{new_name}' instead."
            warn(msg, DeprecationWarning)
            print(f"WARNING: {msg}")
            return self.__getattribute__(new_name)

        # to get the standard error message
        return object().__getattribute__(name)

    @classmethod
    def deprecate(cls, class_name: str, old_name: str, new_name: str) -> None:
        """Deprecate `old_name` in favour of `new_name` for `class_name`.

        Pass "__name__" as first argument.
        """
        cls._deprecations[class_name][old_name] = new_name
Loading