diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..efa407c --- /dev/null +++ b/.gitignore @@ -0,0 +1,162 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ \ No newline at end of file diff --git a/openvs/args.py b/openvs/args.py index b20bbb2..908b18b 100644 --- a/openvs/args.py +++ b/openvs/args.py @@ -1,6 +1,8 @@ +'''Typed argument definitions for argparse type checking and code completion.''' + import os,sys from tap import Tap -from typing import Any, Callable, List, Tuple, Union +from typing import Any, Callable, List, Optional, Tuple, Union from typing_extensions import Literal class ExtractSmilesArgs(Tap): @@ -13,19 +15,28 @@ class ExtractSmilesArgs(Tap): validatefn: str datarootdir: str + class VanillaModelArgs(Tap): + '''Typed args for Vanilla model.''' nnodes: int = 3000 + '''Number of neurons in each hidden layer.''' nBits: int = 1024 + '''Length of the Morgan fingerprint vector.''' dataset_type: Literal["binaryclass", "multiclass", "regression"] + '''Type of prediction task.''' dropout: float = 0.5 + '''Dropout probability used in the dropout layers.''' nlayers: int = 2 + '''Number of identical hidden layers.''' + class TrainArgs(Tap): + '''Typed args for training mode.''' modelid: str = "0" i_iter: int = 1 - train_datafn: str = None - test_datafn: str = None - validate_datafn: str = None + train_datafn: Optional[str] = None + test_datafn: Optional[str] = None + validate_datafn: Optional[str] = None hit_ratio: float = 0.0 score_cutoff: float = 0.0 prefix: str = "" @@ -35,29 +46,32 @@ class TrainArgs(Tap): rand_seed: int = 66666 log_frequency: int = 500 weight_class: bool = False - class_weights: List=[1,1,1,1] + class_weights: List[float] = [1, 1, 1, 1] patience: int = 5 - disable_progress_bar : bool = False + disable_progress_bar: bool = False inferenceDropout: bool = False - + class EvalArgs(Tap): - topNs: List = [10, 100, 1000, 10000] - thresholds: List = [0.2, 0.35, 0.5] - target_threshold: float = None + topNs: List[int] = [10, 100, 1000, 10000] + thresholds: List[float] = [0.2, 0.35, 0.5] + target_threshold: Optional[float] = None target_recall: float = 0.9 #only used in validation set evaluation rand_active_prob: float dataset_type: Literal["test", "validate"] disable_progress_bar : bool = False + class PredictArgs(Tap): - modelfn: str = None - database_type: str = None - database_path: str = None - prediction_path: str = None + '''Typed args for prediction mode.''' + modelfn: Optional[str] = None + database_type: Optional[str] = None + database_path: Optional[str] = None + prediction_path: Optional[str] = None disable_progress_bar: bool = True - batch_size : int = 10000 + '''Whether to disable the progress bar.''' + batch_size: int = 10000 outfileformat: str = "feather" + '''Extension name of the output file.''' run_platform: str="auto" #Literal["gpu", "slurm", "auto"], I need "auto" to be default i_iter: int - diff --git a/openvs/models.py b/openvs/models.py index 7dbf9f9..ecbc198 100644 --- a/openvs/models.py +++ b/openvs/models.py @@ -1,11 +1,16 @@ +'''Network model implementations to accelerate virtual screening.''' + import os,sys import torch import torch.nn as nn import torch.nn.functional as F from openvs.args import VanillaModelArgs + class VanillaNet(nn.Module): - def __init__(self, args: VanillaModelArgs ): + '''A classical fully connected feed-forward network.''' + + def __init__(self, args: VanillaModelArgs): super().__init__() nBits = args.nBits nnodes = args.nnodes @@ -24,7 +29,7 @@ def __init__(self, args: VanillaModelArgs ): self.dropout1 = nn.Dropout(dropoutfreq) self.dropout2 = nn.Dropout(dropoutfreq) self.out_activation = nn.Sigmoid() - + def forward(self, x): x = F.relu(self.bn1(self.fc1(x))) x = self.dropout1(x) @@ -36,8 +41,10 @@ def forward(self, x): x =
self.out_activation(x) return x + class VanillaNet2(nn.Module): - def __init__(self, args: VanillaModelArgs ): + '''A classical fully connected feed-forward network.''' + def __init__(self, args: VanillaModelArgs): super().__init__() nBits = args.nBits nnodes = args.nnodes @@ -55,7 +62,7 @@ def __init__(self, args: VanillaModelArgs ): self.bn3 = nn.BatchNorm1d(num_features=nnodes) self.dropout = nn.Dropout(dropoutfreq) self.out_activation = nn.Sigmoid() - + def forward(self, x): x = F.relu(self.bn1(self.fc1(x))) x = self.dropout(x) @@ -82,13 +89,13 @@ def __init__(self, args: VanillaModelArgs ): self.fc_in = nn.Linear(nBits, nnodes) self.fcs = nn.ModuleList([nn.Linear(nnodes, nnodes) for i in range(self.nlayers)] ) self.fc_out = nn.Linear(nnodes, 1) - + self.bn1 = nn.BatchNorm1d(num_features=nnodes) self.bns = nn.ModuleList([nn.BatchNorm1d(num_features=nnodes) for i in range(self.nlayers)]) - + self.dropout = nn.Dropout(dropoutfreq) self.out_activation = nn.Sigmoid() - + def forward(self, x): x = F.relu(self.bn1(self.fc_in(x))) x = self.dropout(x) diff --git a/openvs/utils/cluster.py b/openvs/utils/cluster.py index 54ce55c..3a0f5e7 100644 --- a/openvs/utils/cluster.py +++ b/openvs/utils/cluster.py @@ -1,29 +1,37 @@ +'''Clustering algorithms.''' + import os,sys import numpy as np import torch from time import time -def one_to_all_tanimoto(x, X): +def one_to_all_tanimoto(x, X) -> torch.Tensor: + '''Calculate 1 - Tanimoto similarity between vector x and each row of vector set X. + + If x and X[i] are identical, the returned distance[i] is 0; if x and X[i] share no bits, distance[i] is 1; otherwise it lies between 0 and 1. + + In clustering terms, the `distance` between two vectors is smaller when they are more similar. + ''' c = torch.sum(X*x, dim=1) a = torch.sum(X,dim=1) b = torch.sum(x) - + return 1-c.type(torch.float)/(a+b-c).type(torch.float) - -def one_to_all_euclidean(x, X, dist_metric="euclidean"): - return torch.sqrt(torch.sum((X-x)**2,dim=1)) + +def one_to_all_euclidean(x, X, dist_metric="euclidean") -> torch.Tensor: + '''Calculate the Euclidean distance between vector x and each row of vector set X.''' + return torch.sqrt(torch.sum((X - x)**2, dim=1)) class BestFirstClustering(): - def __init__(self, cutoff, dist_metric="tanimoto", dtype=torch.uint8): + def __init__(self, cutoff, dist_metric: str="tanimoto", dtype=torch.uint8): + self.cutoff = cutoff if dist_metric == "euclidean": - self.cutoff = cutoff - self.one_to_all_d = one_to_all_gpu_euclidean + self.one_to_all_d = one_to_all_euclidean elif dist_metric == 'tanimoto': - self.cutoff = cutoff self.one_to_all_d = one_to_all_tanimoto if torch.cuda.is_available(): self.use_gpu = True diff --git a/openvs/utils/db_utils.py b/openvs/utils/db_utils.py index 4a709a8..5fdc02b 100644 --- a/openvs/utils/db_utils.py +++ b/openvs/utils/db_utils.py @@ -1,4 +1,7 @@ +'''Filter, extract, compress and convert data files.''' + import os,sys,io +from typing import Any, Hashable import orjson import pandas as pd import tarfile @@ -12,8 +15,20 @@ from distributed import Client, as_completed, wait from dask_jobqueue import SLURMCluster -def dataframe2dict(df, key_column, value_column): - retdict={} +def dataframe2dict(df:pd.DataFrame, key_column:Hashable, value_column:Hashable) -> dict[Hashable,list[Any]]: + '''Convert a DataFrame into a dict. + + Params + ====== + - df: Input DataFrame. + - key_column: One of df's column names. + - value_column: One of df's column names. + + Returns + ======= + A dict whose keys are the items of df[key_column] and whose values are lists of the corresponding items of df[value_column].
+ ''' + retdict: dict[Hashable, list[Any]]={} if key_column not in df.columns: raise Exception(f"{key_column} is a wrong column name.") if value_column not in df.columns: @@ -23,19 +38,19 @@ def dataframe2dict(df, key_column, value_column): return retdict -def map_raw3dfn_to_zincids_wrapper(inargs): +def map_raw3dfn_to_zincids_wrapper(inargs:tuple[str,str,str,set[str]]) -> dict[str,set[str]]: tid4l, dbdir, indexdir, zincids = inargs return map_raw3dfn_to_zincids(tid4l, dbdir, indexdir, zincids) -def map_raw3dfn_to_zincids_tid2l_wrapper(inargs): +def map_raw3dfn_to_zincids_tid2l_wrapper(inargs:tuple[str,str,str,set[str]]) -> dict[str,set[str]]: tid2l, dbdir, indexdir, zincids = inargs return map_raw3dfn_to_zincids_tid2l(tid2l, dbdir, indexdir, zincids) -def map_raw3dfn_to_zincids(tid4l:str, dbdir:str, indexdir:str, zincids:set ): +def map_raw3dfn_to_zincids(tid4l:str, dbdir:str, indexdir:str, zincids:set[str]) -> dict[str,set[str]]: pattern = os.path.join(indexdir, tid4l[:2], f"{tid4l}*.json") db_indexfns = glob(pattern) - zincids_to_3dfns = {} + zincids_to_3dfns:dict[str,str] = {} for db_indexfn in db_indexfns: with open(db_indexfn, 'rb') as infh: dbindex = orjson.loads(infh.read()) @@ -45,22 +60,21 @@ def map_raw3dfn_to_zincids(tid4l:str, dbdir:str, indexdir:str, zincids:set ): if len(zincids_to_3dfns) == len(zincids): break - mol2fn_to_zincids ={} - for zincid in zincids_to_3dfns: - mol2fn = zincids_to_3dfns[zincid] + mol2fn_to_zincids:dict[str,set[str]] ={} + for zincid, mol2fn in zincids_to_3dfns.items(): if mol2fn != "" and os.path.exists(mol2fn): - mol2fn_to_zincids.setdefault(mol2fn, set([])).update([zincid]) + mol2fn_to_zincids.setdefault(mol2fn, set()).update([zincid]) elif os.path.exists( os.path.join(dbdir, tid4l[:2], mol2fn) ): - mol2fn_to_zincids.setdefault(os.path.join(dbdir, tid4l[:2], mol2fn), set([])).update([zincid]) + mol2fn_to_zincids.setdefault(os.path.join(dbdir, tid4l[:2], mol2fn), set()).update([zincid]) else: - mol2fn_to_zincids.setdefault(f"NA_{tid4l}", set([])).update([zincid]) + mol2fn_to_zincids.setdefault(f"NA_{tid4l}", set()).update([zincid]) return mol2fn_to_zincids -def map_raw3dfn_to_zincids_tid2l(tid2l:str, dbdir:str, indexdir:str, zincids:set ): +def map_raw3dfn_to_zincids_tid2l(tid2l:str, dbdir:str, indexdir:str, zincids:set[str]) -> dict[str,set[str]]: pattern = os.path.join(indexdir, tid2l, f"{tid2l}*.json") db_indexfns = glob(pattern) - zincids_to_3dfns = {} + zincids_to_3dfns:dict[str,str] = {} for db_indexfn in db_indexfns: with open(db_indexfn, 'rb') as infh: dbindex = orjson.loads(infh.read()) @@ -70,25 +84,53 @@ def map_raw3dfn_to_zincids_tid2l(tid2l:str, dbdir:str, indexdir:str, zincids:set if len(zincids_to_3dfns) == len(zincids): break - mol2fn_to_zincids ={} - for zincid in zincids_to_3dfns: - mol2fn = zincids_to_3dfns[zincid] + mol2fn_to_zincids:dict[str,set[str]] ={} + for zincid, mol2fn in zincids_to_3dfns.items(): if mol2fn != "" and os.path.exists(mol2fn): - mol2fn_to_zincids.setdefault(mol2fn, set([])).update([zincid]) + mol2fn_to_zincids.setdefault(mol2fn, set()).update([zincid]) elif os.path.exists( os.path.join(dbdir, tid2l, mol2fn) ): - mol2fn_to_zincids.setdefault(os.path.join(dbdir, tid2l, mol2fn), set([])).update([zincid]) + mol2fn_to_zincids.setdefault(os.path.join(dbdir, tid2l, mol2fn), set()).update([zincid]) else: - mol2fn_to_zincids.setdefault(f"NA_{tid2l}", set([])).update([zincid]) + mol2fn_to_zincids.setdefault(f"NA_{tid2l}", set()).update([zincid]) return mol2fn_to_zincids -def
extract_tarmember_to_folder_wrapper(inargs): +def extract_tarmember_to_folder_wrapper(inargs:tuple[set[str],str,str,str,str]) -> int: + '''Wrapper for `extract_tarmember_to_folder`. + + Params + ====== + - inargs: Packed args for `extract_tarmember_to_folder`. + - zincids: A set of zinc ids to be extracted. + - intarfn: Input tar filename. + - outdir: Extract directory path. + - extra = "": Extra infix for saving as a different extracted file. + - logpath = "": Directory for log file saving. + + Returns + ======= + Number of successfully extracted files. + ''' zincids, intarfn, outdir, extra, logpath = inargs return extract_tarmember_to_folder(zincids, intarfn, outdir, extra, logpath) -def extract_tarmember_to_folder( zincids, intarfn, outdir, extra="", logpath=""): +def extract_tarmember_to_folder( zincids:set[str], intarfn:str, outdir:str, extra:str="", logpath:str="") -> int: + '''Filter and extract tar members whose zinc ids are in `zincids`. + + Params + ====== + - zincids: A set of zinc ids to be extracted. + - intarfn: Input tar filename. + - outdir: Extract directory path. + - extra = "": Extra infix for saving as a different extracted file. + - logpath = "": Directory for log file saving. + + Returns + ======= + Number of successfully extracted files. + ''' n = 0 - extracted=set([]) + extracted:set[str]=set() with tarfile.open(intarfn, 'r') as intarfh: for member in intarfh.getmembers(): zincid = os.path.basename(member.name).split(".")[0] @@ -119,7 +161,7 @@ def extract_tarmember_to_folder( zincids, intarfn, outdir, extra="", logpath="") outfname = outfname + ".failed.txt" outfn = os.path.join(logpath, outfname) - content = [] + content:list[str] = [] for zincid in set(zincids) - extracted: content.append(f"{zincid}\n") with open(outfn, 'w') as outfh: @@ -130,11 +172,38 @@ def extract_tarmember_to_folder( zincids, intarfn, outdir, extra="", logpath="") return n -def extract_tarparams_to_folder_wrapper(inargs): +def extract_tarparams_to_folder_wrapper(inargs:tuple[set[str],str,str,str]) -> int: + '''Wrapper for `extract_tarparams_to_folder`. + + Params + ====== + - inargs: Packed args for `extract_tarparams_to_folder`. + - zincids: A set of zinc ids (in filenames) to be extracted. + - intarfn: Input tar filename. + - outdir: Extract directory path. + - extra = "": Extra infix for saving as a different extracted file. + + Returns + ======= + Number of successfully extracted files. + ''' zincids, intarfn, outdir, extra = inargs return extract_tarparams_to_folder(zincids, intarfn, outdir, extra) -def extract_tarparams_to_folder( zincids, intarfn, outdir, extra=""): +def extract_tarparams_to_folder( zincids:set[str], intarfn:str, outdir:str, extra:str=""): + '''Filter and extract `*.params` files in a tar file according to their filenames. + + Params + ====== + - zincids: A set of zinc ids (in filenames) to be extracted. + - intarfn: Input tar filename. + - outdir: Extract directory path. + - extra = "": Extra infix for saving as a different extracted file. + + Returns + ======= + Number of successfully extracted files. + ''' n = 0 with tarfile.open(intarfn, 'r') as intarfh: for member in intarfh.getmembers(): @@ -163,8 +232,20 @@ def extract_tarparams_to_folder( zincids, intarfn, outdir, extra=""): print(f"Successfully extracted {n} members to {outdir}") return n -# add tar members to another tar file -def add_tarmember_to_tarfile(zincids, intarfn, outtarfn, backup=True): +def add_tarmember_to_tarfile(zincids: set[str], intarfn:str, outtarfn:str, backup:bool=True) -> int: + '''Add tar members to another tar file.
+ + Params + ====== + - zincids: A set of zinc ids (in member filenames) to be added. + - intarfn: Input tar filename. + - outtarfn: Output tar filename to add the members to. + - backup = True: Whether to create a `backups` directory and copy the existing output tar file into it first. + + Returns + ======= + Number of successfully added members. + ''' if os.path.exists(outtarfn): tar_mode = "a" if backup: @@ -195,8 +276,20 @@ def add_tarmember_to_tarfile(zincids, intarfn, outtarfn, backup=True): print(f"Successfully added {n} members to {outtarfn}") return n -# add regular files to tar file -def add_files_to_tarfile(regularfns, outtarfn, backup=True, overwrite=False): +def add_files_to_tarfile(regularfns:set[str], outtarfn:str, backup:bool=True, overwrite:bool=False) -> int: + '''Add regular files into a tar file. + + Params + ====== + - regularfns: A set of regular filenames to be added to the tar file. + - outtarfn: Output tar filename. If the tar file already exists, update it; otherwise create a new one. + - backup = True: Whether to create a `backups` directory and copy the existing output tar file into it first. + - overwrite = False: Whether to regenerate the tar file when it already exists. + + Returns + ======= + Number of successfully added files. + ''' if os.path.exists(outtarfn): tar_mode = "a" if backup: diff --git a/openvs/utils/params_utils.py b/openvs/utils/params_utils.py index 39c85d6..7362afe 100644 --- a/openvs/utils/params_utils.py +++ b/openvs/utils/params_utils.py @@ -1,6 +1,9 @@ +'''Generate params files from specified structure files.''' + from __future__ import print_function import os,sys +from typing import Literal, Optional, Union os.environ["OPENBLAS_NUM_THREADS"] = "1" import subprocess as sp from glob import glob @@ -11,7 +14,12 @@ from .db_utils import add_files_to_tarfile import multiprocessing as mp -def run_cmd(cmd): +def run_cmd(cmd) -> Union[tuple[str,int],Literal[0]]: + '''Run a command in a subprocess. + + Return 0 on success; + + otherwise return the command and the return code of the subprocess.''' p = sp.Popen(cmd, shell=True) p.communicate() ret = int(p.returncode) @@ -19,9 +27,36 @@ def run_cmd(cmd): return " ".join(cmd), ret return ret -def gen_params_from_folder(indir, outdir, mode='multiprocessing', - overwrite=False, nopdb=False, mol2gen_app=None, - multimol2=False, infer_atomtypes=False, queue='cpu'): + +def gen_params_from_folder(indir: str, + outdir: str, + mode: Literal['multiprocessing', 'slurm', 'local', + 'debug'] = 'multiprocessing', + overwrite: bool = False, + nopdb: bool = False, + mol2gen_app:Optional[str]=None, + multimol2: bool = False, + infer_atomtypes: bool = False, + queue: str = 'cpu') -> None: + '''Search for *.mol2 files in the specified directory and generate params files. + + Params + ====== + - indir: A directory containing mol2 files. + - outdir: Output directory for the generated params files. + - mode = 'multiprocessing': Can be 'multiprocessing', 'slurm', 'local', or 'debug'. + - overwrite = False: Whether to regenerate params files that were generated before. + - nopdb = False: Do not write out PDB files. + - mol2gen_app = None: Path of the script that generates params files from mol2. If None, search under $ROSETTAHOME. + - multimol2 = False: Whether the input mol2 files contain multiple structures. + - infer_atomtypes = False: Infer the correct atom type for some special cases + with a wrong input atom type. + - queue = 'cpu': Destination queue for each worker job. + + Returns + ======= + None.
+ ''' pattern = os.path.join(indir, "*.mol2") mol2fns = sorted(glob(pattern)) print("Number of mol2fns:", len(mol2fns)) @@ -45,7 +80,7 @@ def gen_params_from_folder(indir, outdir, mode='multiprocessing', print(f"ROSETTAHOME: {rosettahome}") mol2gen_app = os.path.join(rosettahome, "source/scripts/python/public/generic_potential/mol2genparams.py") - + if not os.path.exists(outdir): os.makedirs(outdir) print(f"Made dir: {outdir}") @@ -69,7 +104,7 @@ def gen_params_from_folder(indir, outdir, mode='multiprocessing', cmd.append("--multimol2") if infer_atomtypes: cmd.append("--infer_atomtypes") - + cmd = " ".join(cmd) #print(cmd) if mode == 'slurm': @@ -104,7 +139,22 @@ def gen_params_from_folder(indir, outdir, mode='multiprocessing', results = pool.map(run_cmd, joblist) print(results) -def gen_tarparams_from_list(paramsfns, outfn, overwrite=False): + +def gen_tarparams_from_list(paramsfns: set[str], + outfn: str, + overwrite: bool = False) -> None: + '''Compress a set of params files into a tar file. + + Params + ====== + - paramsfns: A set of params filenames to be compressed. + - outfn: Output tar filename. If the tar file already exists, update it; otherwise create a new one. + - overwrite = False: Whether to regenerate the tar file when it already exists. + + Returns + ======= + None. + ''' if not overwrite and os.path.exists(outfn): print(f"{outfn} exists, skip.") return @@ -112,8 +162,19 @@ def gen_tarparams_from_list(paramsfns, outfn, overwrite=False): n = add_files_to_tarfile(paramsfns, outfn, backup=False, overwrite=True) -def ligandlist_from_tarparamsfn(intarfn, outfn): - ligandlines = [] +def ligandlist_from_tarparamsfn(intarfn: str, outfn: str) -> None: + '''Create a file containing the ligand ids parsed from a tar file. + + Params + ====== + - intarfn: A tar file whose members may contain ligand ids. + - outfn: The output ligand list file. + + Returns + ======= + None. + ''' + ligandlines:list[str] = [] with tarfile.open(intarfn, 'r') as intar: raw_members = intar.getmembers() for member in raw_members: diff --git a/openvs/utils/rosetta_utils.py b/openvs/utils/rosetta_utils.py index de238ff..51e88a2 100644 --- a/openvs/utils/rosetta_utils.py +++ b/openvs/utils/rosetta_utils.py @@ -1,5 +1,9 @@ +'''Parse Rosetta log files.''' + from __future__ import print_function +from io import BytesIO, TextIOWrapper import os,sys +from typing import Iterable, Literal, Optional, Union import numpy as np import pandas as pd import subprocess @@ -26,10 +30,20 @@ def _check_tags_ndx(fn,tags, line_marker = "nomarker", tags_marker = "descriptio break return valid_tags,indices -def _check_tags_ndx_line(line, tags): - - indices = [] - valid_tags = [] +def _check_tags_ndx_line(line:str, tags:Iterable[str]) -> tuple[list[str], list[int]]: + '''Split a line and find the indices of the given tags in it. + + Params + ====== + - line: A string in a stream. + - tags: A group of tag names. + + Returns + ======= + A list of valid tags and a list of their indices. + ''' + indices:list[int] = [] + valid_tags:list[str] = [] fields = line.strip().split() for tag in tags: @@ -45,8 +59,21 @@ def _check_tags_ndx_line(line, tags): return valid_tags,indices -def _line_parser(line,indices=None,dtypes=(float,) ): - values = [] +def _line_parser(line:str, indices:Optional[list[int]] = None,dtypes:tuple[type]=(float,) ) -> Union[list[object],Literal[False],str]: + '''Parse fields in one line. + + Params + ====== + - line: A line in a stream. + - indices = None: Indices to locate in the split line. If `None`, return the original line.
+ - dtypes = (float,): A group of data types to convert the string fields into. When its length is 1, all fields are parsed as that same data type. + + Returns + ======= + - False: when the line reports a failed result, or when some field in the line fails to parse. + - Otherwise, a list of parsed values. + ''' + values:list[object] = [] if indices is None: return line else: @@ -57,7 +84,7 @@ def _line_parser(line,indices=None,dtypes=(float,) ): if len(dtypes) == 1: dtype = dtypes[0] else: - dtype = dtypes[indices.index(i)] + dtype = dtypes[indices.index(i)] # FIXME: is this right? try: if 'nan' in fields[i]: return False values.append(dtype(fields[i])) @@ -67,24 +94,50 @@ def _line_parser(line,indices=None,dtypes=(float,) ): fields[i] except IndexError: continue - + raise Exception("Cannot convert %s to %s"%(fields[i],dtype)) - - + + if len(values) != len(indices): return False return values -def read_log_fstream(infh, tags, dtypes, line_marker = "SCORE:" , tags_marker = "description", header_line=None, ignore_incomplete=False, verbose=False): - valid_tags = [] - values = [] +def read_log_fstream(infh: Union[TextIOWrapper, BytesIO], + tags: list[str], + dtypes: tuple[type], + line_marker: str = "SCORE:", + tags_marker: str = "description", + header_line: Optional[str] = None, + ignore_incomplete: bool = False, + verbose: int = False) -> tuple[list[str], list[Union[list[object],str]]]: + '''Read a log file stream and parse the lines into values. + + Params + ====== + - infh: A file handle or any str/bytes stream. + - tags: A list of tag names. + - dtypes: A group of data types, corresponding to tags. + - line_marker = 'SCORE:': Parse lines only if they start with line_marker. Parse all lines when line_marker is `nomarker`. + - tags_marker = 'description': Lines containing it are parsed as the header line. + - header_line = None: A string containing the tags, supposed to be the header line of the file. If None, use `tags_marker` to find the header line. + - ignore_incomplete = False: Whether to skip incompletely parsed lines. + - verbose = False: Log output level; output more when it is higher. + + Returns + ======= + 1. A list of valid tag names, + 2. and a list of parsed values. + ''' + valid_tags:list[str] = [] + values:list[Union[list[object],str]] = [] + indices = [] if header_line is not None: valid_tags, indices = _check_tags_ndx_line(header_line, tags) assert len(valid_tags) == len(tags), "Valid tags: %s doesn't match specified tags %s!"%(valid_tags,tags) if valid_tags[-1] == "description": indices[-1] = -1 - + for l in infh: try: line = l.decode() @@ -94,7 +147,7 @@ def read_log_fstream(infh, tags, dtypes, line_marker = "SCORE:" , tags_marker = pass elif (not line.startswith(line_marker)): continue - + if header_line is None and tags_marker in line: valid_tags, indices = _check_tags_ndx_line(line, tags) assert len(valid_tags) == len(tags), f"Valid tags: {valid_tags} doesn't match specified tags {tags}!" @@ -118,8 +171,25 @@ def read_log_fstream(infh, tags, dtypes, line_marker = "SCORE:" , tags_marker = return valid_tags, values -def read_log(fn, tags, dtypes, line_marker = "SCORE:" , tags_marker = "description", header_line=None, ignore_incomplete=False, verbose=False): - +def read_log(fn:str, tags:list[str], dtypes:tuple[type], line_marker:str = "SCORE:" , tags_marker:str = "description", header_line:Optional[str]=None, ignore_incomplete: bool = False, verbose: bool = False) -> tuple[list[str], list[Union[list[object],str]]]: + '''Open a log file and parse the lines into values. + + Params + ====== + - fn: Filename of the log file.
+ - tags: A list of tag names. + - dtypes: A group of data types, corresponding to tags. + - line_marker = 'SCORE:': Parse lines only if they start with line_marker. Parse all lines when line_marker is `nomarker`. + - tags_marker = 'description': Lines containing it are parsed as the header line. + - header_line = None: A string containing the tags, supposed to be the header line of the file. If None, use `tags_marker` to find the header line. + - ignore_incomplete = False: Whether to skip incompletely parsed lines. + - verbose = False: Log output level; output more when it is higher. + + Returns + ======= + 1. A list of valid tag names, + 2. and a list of parsed values. + ''' if not os.path.exists(fn): raise IOError("%s doesn't exist!"%fn) @@ -131,8 +201,20 @@ def read_log(fn, tags, dtypes, line_marker = "SCORE:" , tags_marker = "descripti return read_log_fstream(infn, tags, dtypes, line_marker, tags_marker, header_line, ignore_incomplete, verbose) -def read_score_line(fn,tags, verbose=False): - +def read_score_line(fn:str,tags:list[str], verbose:bool=False) -> tuple[list[str], list[Union[list[object],str]]]: + '''Read the score lines of a log file and parse them into values. + + Params + ====== + - fn: Filename of the log file. + - tags: A list of tag names. + - verbose = False: Log output level; output more when it is higher. + + Returns + ======= + 1. A list of valid tag names, + 2. and a list of parsed values. + ''' if not os.path.exists(fn): raise IOError("%s doesn't exist!"%fn) @@ -142,4 +224,3 @@ def read_score_line(fn,tags, verbose=False): verbose=verbose) return valid_tags, values - diff --git a/openvs/utils/utils.py b/openvs/utils/utils.py index db629ce..f337fb6 100644 --- a/openvs/utils/utils.py +++ b/openvs/utils/utils.py @@ -1,3 +1,5 @@ +'''Useful functions for smi file handling, fingerprint calculation and statistical analysis.''' + import os,sys from time import time # import scientific py @@ -9,7 +11,8 @@ from rdkit import Chem from rdkit.Chem import AllChem -from typing import Any, Callable, List, Tuple, Union +from collections.abc import Iterable +from typing import Any, Callable, List, Sequence, Tuple, Union from typing_extensions import Literal import orjson @@ -20,60 +23,137 @@ from sklearn.metrics import auc, mean_absolute_error, mean_squared_error, precision_recall_curve, r2_score,\ roc_auc_score, accuracy_score, log_loss, precision_score, recall_score -def load_smiles_file(input_file, delimiter=" "): +def load_smiles_file(input_file: str, delimiter: str = " ") -> pd.DataFrame: + '''Load smiles information from a csv-like smi file. + + Params + ====== + + - input_file: Path of the input smi file. + - delimiter = ' ': Delimiter for the csv-like file. + + Returns + ======= + + A DataFrame including smiles, ids and any other description columns. + ''' data = pd.read_csv(input_file, delimiter=delimiter) return data -def smi2fp_bitstring_helper(smi, morgan_radius =2, nBits=1024): +def smi2fp_bitstring_helper(smi: str, morgan_radius: int =2, nBits: int = 1024) -> str: + '''Convert a smiles string into a Morgan fingerprint (0/1 bit string). + + Params + ====== + - smi: A smiles string. + - morgan_radius = 2: Morgan radius of the substructure used to calculate the fingerprint; must be an int. + - nBits = 1024: Length of the fingerprint vector. + + Returns + ======= + The fingerprint as a 0/1 bit string.
+ ''' mol = Chem.MolFromSmiles(smi) fp = AllChem.GetMorganFingerprintAsBitVect(mol, morgan_radius, nBits=nBits) fp_bitstring = fp.ToBitString() return fp_bitstring -def smi2fp_hexstring_helper(smi, morgan_radius =2, nBits=1024): +def smi2fp_hexstring_helper(smi:str, morgan_radius:int =2, nBits:int=1024) -> str: + '''Convert a smiles string into a Morgan fingerprint encoded as a hexadecimal string. + + Params + ====== + - smi: A smiles string. + - morgan_radius = 2: Morgan radius of the substructure used to calculate the fingerprint; must be an int. + - nBits = 1024: Length of the fingerprint vector. + + Returns + ======= + The fingerprint of the smiles string as a hexadecimal string. + ''' mol = Chem.MolFromSmiles(smi) fp = AllChem.GetMorganFingerprintAsBitVect(mol, morgan_radius, nBits=nBits) fp_bitstring = fp.ToBitString() fp_hexstring = format(int(fp_bitstring, 2), 'x') return fp_hexstring -def smiles_to_bitstrings(smiles_list, morgan_radius =2, nBits=1024): - """ - morgan_radius = 2 is roughly equivalent to ECFP 4 +def smiles_to_bitstrings(smiles_list:Iterable[str], morgan_radius: int = 2, nBits: int = 1024) -> tuple[str, ...]: + """Convert a list of smiles strings into bit strings in parallel. + + Params + ====== + - smiles_list: A list of smiles strings. + - morgan_radius = 2: Morgan radius of the substructure; 2 is roughly equivalent to ECFP4. + - nBits = 1024: Length of the fingerprint vector. + + Returns + ======= + + A tuple of calculated Morgan fingerprints (0/1 bit strings). """ bitstrings = [] jobs = [] for i, smi in enumerate(smiles_list): jobs.append(delayed(smi2fp_bitstring_helper)(smi, morgan_radius, nBits)) - + bitstrings = compute(*jobs, scheduler="processes") return bitstrings -def smiles_to_hexstrings(smiles_list, morgan_radius =2, nBits=1024): - """ - morgan_radius = 2 is roughly equivalent to ECFP 4 +def smiles_to_hexstrings(smiles_list:Iterable[str], morgan_radius:int =2, nBits:int=1024) -> tuple[str, ...]: + """Convert a list of smiles strings into hexadecimal fingerprint strings in parallel. + + Params + ====== + - smiles_list: A list of smiles strings. + - morgan_radius = 2: Morgan radius of the substructure; 2 is roughly equivalent to ECFP4. + - nBits = 1024: Length of the fingerprint vector. + + Returns + ======= + + A tuple of calculated Morgan fingerprints (hex strings). """ bitstrings = [] jobs = [] for i, smi in enumerate(smiles_list): jobs.append(delayed(smi2fp_hexstring_helper)(smi, morgan_radius, nBits)) - + bitstrings = compute(*jobs, scheduler="processes") return bitstrings -def smiles_to_hexstrings_slow(smiles_list, morgan_radius =2, nBits=1024): - """ - morgan_radius = 2 is roughly equivalent to ECFP 4 +def smiles_to_hexstrings_slow(smiles_list:Iterable[str], morgan_radius:int =2, nBits:int=1024) -> list[str]: + """Convert a list of smiles strings into hexadecimal fingerprint strings sequentially. + + Params + ====== + - smiles_list: A list of smiles strings. + - morgan_radius = 2: Morgan radius of the substructure; 2 is roughly equivalent to ECFP4. + - nBits = 1024: Length of the fingerprint vector. + + Returns + ======= + A list of calculated Morgan fingerprints (hex strings).
""" - bitstrings = [] + bitstrings:list[bytes] = [] for i, smi in enumerate(smiles_list): bitstrings.append( smi2fp_hexstring_helper(smi, morgan_radius, nBits) ) return bitstrings -def smiles_to_bitarrays(smiles_list, radius=2, nBits=1024, useFeature=True, useChirality=True): - """ - morgan_radius = 2 is roughly equivalent to ECFP 4 +def smiles_to_bitarrays(smiles_list: Iterable[str], radius:int=2, nBits:int=1024, useFeature:bool=True, useChirality:bool=True) -> np.ndarray: + """Convert a list of smiles strings into an int matrix. + + Params + ====== + - smiles_list: A group of smiles strings. + - radius = 2: Morgan radius of substructure,2 is roughly equivalent to ECFP 4. + - nBits = 1024: Length of the fingerprint vector. + - useFeature = True: If False,use ConnectivityMorgan; if True: use FeatureMorgan. + - useChirality = True: If True, chirality information will be included as a part of the bond invariants. + + Returns + ======= + An int numpy matrix(n_smi*nBits),each column is a morgen 0/1 vector fingerprint. """ bitarrays = np.zeros((len(smiles_list), nBits), dtype=int) for i, smi in enumerate(smiles_list): @@ -84,7 +164,7 @@ def smiles_to_bitarrays(smiles_list, radius=2, nBits=1024, useFeature=True, useC return bitarrays -def smiles_to_bitarrays_slow(smiles_list, morgan_radius =2, nBits=1024): +def smiles_to_bitarrays_slow(smiles_list:Iterable[str], morgan_radius:int =2, nBits:int=1024): """ morgan_radius = 2 is roughly equivalent to ECFP 4 """ @@ -102,25 +182,63 @@ def smiles_to_bitarrays_slow(smiles_list, morgan_radius =2, nBits=1024): return fps_bitarray -def smiles_to_binary_fingerprints(smis, radius=2, nBits=1024, useFeature=True, useChirality=True): - fps = [] +def smiles_to_binary_fingerprints(smis:Iterable[str], radius:int=2, nBits:int=1024, useFeature:bool=True, useChirality:bool=True) -> list[bytes]: + """Convert a list of smiles strings into an list of hex-format morgan fingerprints. + + Params + ====== + - smiles_list: A group of smiles strings. + - radius = 2: Morgan radius of substructure,2 is roughly equivalent to ECFP 4. + - nBits = 1024: Length of the fingerprint vector. + - useFeature = True: If False,use ConnectivityMorgan; if True: use FeatureMorgan. + - useChirality = True: If True, chirality information will be included as a part of the bond invariants. + + Returns + ======= + An list of morgan fingerprints,each item is a morgen hex-format fingerprint. + """ + fps:list[bytes] = [] for i, smi in enumerate(smis): m = Chem.MolFromSmiles(smi) fp = AllChem.GetMorganFingerprintAsBitVect(m, radius=radius, nBits=nBits, useFeatures=useFeature, useChirality=useChirality) fps.append(fp.ToBinary()) return fps -def get_accuracys(targets: List[int], preds: Union[List[float], List[List[float]]], thresholds: float = [0.2, 0.35, 0.5] ) -> float: +def get_accuracys(targets: Sequence[int], preds: Union[List[float], List[List[float]]], thresholds: Iterable[float] = (0.2, 0.35, 0.5) ) -> np.ndarray: + """Calculate accuracy according to a group of thresholds. + + Params + ====== + - targets: 0/1 sample values + - preds: Float prediction vector(s).If a prediction > threshold,it will be 1;otherwise 0. + - thresholds: A group of thresholds for calculating accuracy respectively. + + Returns + ======= + A numpy n_threshold * n_preds ndarray,each line is accuracy score between targets and predictions. 
+ """ acc = [] for threshold in thresholds: hard_preds = [1 if p > threshold else 0 for p in preds] # binary prediction acc.append(accuracy_score(targets, hard_preds)) return np.array(acc) -def get_precisions(targets: List[int], - preds: Union[List[float], - List[List[float]]], - thresholds: List[float]=[0.2, 0.35, 0.5]) -> List[float]: +def get_precisions(targets: Sequence[int], + preds: Union[List[float], + List[List[float]]], + thresholds: Iterable[float]=(0.2, 0.35, 0.5)) -> np.ndarray: + """Calculate precision according to a group of thresholds. + + Params + ====== + - targets: 0/1 sample values + - preds: Float prediction vector(s).If a prediction > threshold,it will be 1;otherwise 0. + - thresholds: A group of thresholds for calculating precision respectively. + + Returns + ======= + A numpy n_threshold * n_preds ndarray,each line is precision score between targets and predictions. + """ precisions = [] for threshold in thresholds: hard_preds = [1 if p > threshold else 0 for p in preds] @@ -137,9 +255,21 @@ def get_FPDE(ys, ys_pred, threshold, random_p): fpde = TP_topN/TP_randomN return fpde -def get_recalls(targets: List[int], - preds: Union[List[float], List[List[float]]], - thresholds: List[float]=[0.2, 0.35, 0.5]) -> List[float]: +def get_recalls(targets: Sequence[int], + preds: Union[List[float], List[List[float]]], + thresholds: Iterable[float]=(0.2, 0.35, 0.5)) -> np.ndarray: + """Calculate recalls according to a group of thresholds. + + Params + ====== + - targets: 0/1 sample values + - preds: Float prediction vector(s).If a prediction > threshold,it will be 1;otherwise 0. + - thresholds: A group of thresholds for calculating recalls respectively. + + Returns + ======= + A numpy n_threshold * n_preds ndarray,each line is recall score between targets and predictions. + """ recalls = [] for threshold in thresholds: hard_preds = [1 if p > threshold else 0 for p in preds] @@ -154,7 +284,21 @@ def get_data_kept_percentage(preds, thresholds): retval.append(np.sum(preds>=t)/ntotal) return retval -def get_enrichment_factors(random_p, topNs, preds, truth): +def get_enrichment_factors(random_p: float, topNs: Iterable[int], preds: Sequence[float] , truth: Sequence[float]) -> np.ndarray: + """Calculate enrichment factors of top N. + + Params + ====== + - random_p: + - topNs: A group of integers to calculate topN enrichment factors. + - preds: Float prediction vector. + - truth: Sample vector. + - thresholds: A group of thresholds for calculating recalls respectively. + + Returns + ======= + A numpy n_topNs ndarray,each line is enrichment factor between predictions and samples. 
+ """ I = np.argsort(preds)[::-1] #descending order sorted_truth = np.array(truth)[I] EFs = [] @@ -181,11 +325,11 @@ def get_enrichment_factors2(random_p, topNs, preds, truth): thresholds.append(sorted_pres[N]) return np.array(EFs), np.array(thresholds) -def recall_threshold_detector(target_recall: float, - targets: List[int], preds: Union[List[float], +def recall_threshold_detector(target_recall: float, + targets: List[int], preds: Union[List[float], List[List[float]]], eps: float = 1E-6) -> float: def biserch(target_value, xl, xr, vl=None, vr=None): - + if vl is None: vl = get_recalls(targets, preds, [xl] )[0] if vr is None: @@ -207,7 +351,7 @@ def biserch(target_value, xl, xr, vl=None, vr=None): else: return biserch(target_value, xm, xr, vm, vr) - + threshold, recall = biserch(target_recall, 0.0, 1.0) recall2 = get_recalls(targets, preds, [threshold])[0] print(threshold, recall, recall2) @@ -218,10 +362,21 @@ def get_top_std(pred_mean, pred_std, N=100): return np.array(pred_std)[I][:N] def get_bottom_std(pred_mean, pred_std, N=100): - I = np.argsort(pred_mean) #descending order + I = np.argsort(pred_mean) #ascending order return np.array(pred_std)[I][:N] -def load_configfn(configfn): + +def load_configfn(configfn: str) -> dict: + '''Load a binary config file and convert into a config dict. + + Params + ====== + configfn: Config file name. + + Returns + ======= + A config dict. + ''' with open(configfn, 'rb') as f: config = orjson.loads(f.read()) return config