Skip to content

Commit ab5f1e5

Browse files
committed
docstrings, type hints for public items.
1 parent 1564097 commit ab5f1e5

File tree

8 files changed

+710
-130
lines changed

8 files changed

+710
-130
lines changed

.gitignore

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
# Byte-compiled / optimized / DLL files
2+
__pycache__/
3+
*.py[cod]
4+
*$py.class
5+
6+
# C extensions
7+
*.so
8+
9+
# Distribution / packaging
10+
.Python
11+
build/
12+
develop-eggs/
13+
dist/
14+
downloads/
15+
eggs/
16+
.eggs/
17+
lib/
18+
lib64/
19+
parts/
20+
sdist/
21+
var/
22+
wheels/
23+
share/python-wheels/
24+
*.egg-info/
25+
.installed.cfg
26+
*.egg
27+
MANIFEST
28+
29+
# PyInstaller
30+
# Usually these files are written by a python script from a template
31+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
32+
*.manifest
33+
*.spec
34+
35+
# Installer logs
36+
pip-log.txt
37+
pip-delete-this-directory.txt
38+
39+
# Unit test / coverage reports
40+
htmlcov/
41+
.tox/
42+
.nox/
43+
.coverage
44+
.coverage.*
45+
.cache
46+
nosetests.xml
47+
coverage.xml
48+
*.cover
49+
*.py,cover
50+
.hypothesis/
51+
.pytest_cache/
52+
cover/
53+
54+
# Translations
55+
*.mo
56+
*.pot
57+
58+
# Django stuff:
59+
*.log
60+
local_settings.py
61+
db.sqlite3
62+
db.sqlite3-journal
63+
64+
# Flask stuff:
65+
instance/
66+
.webassets-cache
67+
68+
# Scrapy stuff:
69+
.scrapy
70+
71+
# Sphinx documentation
72+
docs/_build/
73+
74+
# PyBuilder
75+
.pybuilder/
76+
target/
77+
78+
# Jupyter Notebook
79+
.ipynb_checkpoints
80+
81+
# IPython
82+
profile_default/
83+
ipython_config.py
84+
85+
# pyenv
86+
# For a library or package, you might want to ignore these files since the code is
87+
# intended to run in multiple environments; otherwise, check them in:
88+
# .python-version
89+
90+
# pipenv
91+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
93+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
94+
# install all needed dependencies.
95+
#Pipfile.lock
96+
97+
# poetry
98+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99+
# This is especially recommended for binary packages to ensure reproducibility, and is more
100+
# commonly ignored for libraries.
101+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102+
#poetry.lock
103+
104+
# pdm
105+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106+
#pdm.lock
107+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108+
# in version control.
109+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
110+
.pdm.toml
111+
.pdm-python
112+
.pdm-build/
113+
114+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
115+
__pypackages__/
116+
117+
# Celery stuff
118+
celerybeat-schedule
119+
celerybeat.pid
120+
121+
# SageMath parsed files
122+
*.sage.py
123+
124+
# Environments
125+
.env
126+
.venv
127+
env/
128+
venv/
129+
ENV/
130+
env.bak/
131+
venv.bak/
132+
133+
# Spyder project settings
134+
.spyderproject
135+
.spyproject
136+
137+
# Rope project settings
138+
.ropeproject
139+
140+
# mkdocs documentation
141+
/site
142+
143+
# mypy
144+
.mypy_cache/
145+
.dmypy.json
146+
dmypy.json
147+
148+
# Pyre type checker
149+
.pyre/
150+
151+
# pytype static type analyzer
152+
.pytype/
153+
154+
# Cython debug symbols
155+
cython_debug/
156+
157+
# PyCharm
158+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
159+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
160+
# and can be added to the global gitignore or merged into this file. For a more nuclear
161+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
162+
#.idea/

openvs/args.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
'''Typed argument definitions for argparse type checking and code completion.'''
2+
13
import os,sys
24
from tap import Tap
3-
from typing import Any, Callable, List, Tuple, Union
5+
from typing import Any, Callable, List, Optional, Tuple, Union
46
from typing_extensions import Literal
57

68
class ExtractSmilesArgs(Tap):
@@ -13,19 +15,28 @@ class ExtractSmilesArgs(Tap):
1315
validatefn: str
1416
datarootdir: str
1517

18+
1619
class VanillaModelArgs(Tap):
20+
'''Typed args for Vanilla model.'''
1721
nnodes: int = 3000
22+
'''Number of neuron nodes in one layer.'''
1823
nBits: int = 1024
24+
'''Length of morgan fingerprint vector.'''
1925
dataset_type: Literal["binaryclass", "multiclass", "regression"]
26+
'''Prediction task type (binary classification, multiclass, or regression).'''
2027
dropout: float = 0.5
28+
'''Dropout factor in dropout layer.'''
2129
nlayers: int = 2
30+
'''Number of same layers.'''
31+
2232

2333
class TrainArgs(Tap):
34+
'''Typed args for training mode.'''
2435
modelid: str = "0"
2536
i_iter: int = 1
26-
train_datafn: str = None
27-
test_datafn: str = None
28-
validate_datafn: str = None
37+
train_datafn: Optional[str] = None
38+
test_datafn: Optional[str] = None
39+
validate_datafn: Optional[str] = None
2940
hit_ratio: float = 0.0
3041
score_cutoff: float = 0.0
3142
prefix: str = ""
@@ -35,29 +46,32 @@ class TrainArgs(Tap):
3546
rand_seed: int = 66666
3647
log_frequency: int = 500
3748
weight_class: bool = False
38-
class_weights: List=[1,1,1,1]
49+
class_weights: List[float] = [1, 1, 1, 1]
3950
patience: int = 5
40-
disable_progress_bar : bool = False
51+
disable_progress_bar: bool = False
4152
inferenceDropout: bool = False
42-
53+
4354

4455
class EvalArgs(Tap):
45-
topNs: List = [10, 100, 1000, 10000]
46-
thresholds: List = [0.2, 0.35, 0.5]
47-
target_threshold: float = None
56+
topNs: List[int] = [10, 100, 1000, 10000]
57+
thresholds: List[float] = [0.2, 0.35, 0.5]
58+
target_threshold: Optional[float] = None
4859
target_recall: float = 0.9 #only used in validation set evaluation
4960
rand_active_prob: float
5061
dataset_type: Literal["test", "validate"]
5162
disable_progress_bar : bool = False
5263

64+
5365
class PredictArgs(Tap):
54-
modelfn: str = None
55-
database_type: str = None
56-
database_path: str = None
57-
prediction_path: str = None
66+
'''Typed args for predicting mode.'''
67+
modelfn: Optional[str] = None
68+
database_type: Optional[str] = None
69+
database_path: Optional[str] = None
70+
prediction_path: Optional[str] = None
5871
disable_progress_bar: bool = True
59-
batch_size : int = 10000
72+
'''Whether to disable the progress bar.'''
73+
batch_size: int = 10000
6074
outfileformat: str = "feather"
75+
'''Extension name of the output file.'''
6176
run_platform: str="auto" #Literal["gpu", "slurm", "auto"], I need "auto" to be default
6277
i_iter: int
63-

openvs/models.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
1+
'''Network model implementations to accelerate visual screening.'''
2+
13
import os,sys
24
import torch
35
import torch.nn as nn
46
import torch.nn.functional as F
57
from openvs.args import VanillaModelArgs
68

9+
710
class VanillaNet(nn.Module):
8-
def __init__(self, args: VanillaModelArgs ):
11+
'''A classical one-to-one network.'''
12+
13+
def __init__(self, args: VanillaModelArgs):
914
super().__init__()
1015
nBits = args.nBits
1116
nnodes = args.nnodes
@@ -24,7 +29,7 @@ def __init__(self, args: VanillaModelArgs ):
2429
self.dropout1 = nn.Dropout(dropoutfreq)
2530
self.dropout2 = nn.Dropout(dropoutfreq)
2631
self.out_activation = nn.Sigmoid()
27-
32+
2833
def forward(self, x):
2934
x = F.relu(self.bn1(self.fc1(x)))
3035
x = self.dropout1(x)
@@ -36,8 +41,10 @@ def forward(self, x):
3641
x = self.out_activation(x)
3742
return x
3843

44+
3945
class VanillaNet2(nn.Module):
40-
def __init__(self, args: VanillaModelArgs ):
46+
'''A classical one-to-one network.'''
47+
def __init__(self, args: VanillaModelArgs):
4148
super().__init__()
4249
nBits = args.nBits
4350
nnodes = args.nnodes
@@ -55,7 +62,7 @@ def __init__(self, args: VanillaModelArgs ):
5562
self.bn3 = nn.BatchNorm1d(num_features=nnodes)
5663
self.dropout = nn.Dropout(dropoutfreq)
5764
self.out_activation = nn.Sigmoid()
58-
65+
5966
def forward(self, x):
6067
x = F.relu(self.bn1(self.fc1(x)))
6168
x = self.dropout(x)
@@ -82,13 +89,13 @@ def __init__(self, args: VanillaModelArgs ):
8289
self.fc_in = nn.Linear(nBits, nnodes)
8390
self.fcs = nn.ModuleList([nn.Linear(nnodes, nnodes) for i in range(self.nlayers)] )
8491
self.fc_out = nn.Linear(nnodes, 1)
85-
92+
8693
self.bn1 = nn.BatchNorm1d(num_features=nnodes)
8794
self.bns = nn.ModuleList([nn.BatchNorm1d(num_features=nnodes) for i in range(self.nlayers)])
88-
95+
8996
self.dropout = nn.Dropout(dropoutfreq)
9097
self.out_activation = nn.Sigmoid()
91-
98+
9299
def forward(self, x):
93100
x = F.relu(self.bn1(self.fc_in(x)))
94101
x = self.dropout(x)

openvs/utils/cluster.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,37 @@
1+
'''Clustering algorithms.'''
2+
13
import os,sys
24
import numpy as np
35
import torch
46
from time import time
57

6-
def one_to_all_tanimoto(x, X):
8+
def one_to_all_tanimoto(x, X) -> torch.Tensor:
9+
'''Calculate 1 - tanimoto similarity between vector x and vector set X.
10+
11+
If x and X[:,i] are the same, tanimoto[i] is 0; if x and X[:,i] are totally different, tanimoto[i] is 1; otherwise it is between 0 and 1.
12+
13+
In clustering algorithms, two vectors' `distance` is shorter when they are more similar.
14+
'''
715
c = torch.sum(X*x, dim=1)
816
a = torch.sum(X,dim=1)
917
b = torch.sum(x)
10-
18+
1119
return 1-c.type(torch.float)/(a+b-c).type(torch.float)
12-
1320

14-
def one_to_all_euclidean(x, X, dist_metric="euclidean"):
15-
return torch.sqrt(torch.sum((X-x)**2,dim=1))
21+
22+
def one_to_all_euclidean(x, X, dist_metric="euclidean") -> torch.Tensor:
23+
'''Calculate euclidean distance between vector x and vector set X.'''
24+
return torch.sqrt(torch.sum((X - x)**2, dim=1))
1625

1726

1827
class BestFirstClustering():
19-
def __init__(self, cutoff, dist_metric="tanimoto", dtype=torch.uint8):
28+
def __init__(self, cutoff, dist_metric: str="tanimoto", dtype=torch.uint8):
2029

30+
self.cutoff = cutoff
2131
if dist_metric == "euclidean":
22-
self.cutoff = cutoff
23-
self.one_to_all_d = one_to_all_gpu_euclidean
32+
self.one_to_all_d = one_to_all_euclidean
2433

2534
elif dist_metric == 'tanimoto':
26-
self.cutoff = cutoff
2735
self.one_to_all_d = one_to_all_tanimoto
2836
if torch.cuda.is_available():
2937
self.use_gpu = True

0 commit comments

Comments
 (0)