Skip to content

Commit f64ab2e

Browse files
Adjust benchmarking code to changed interfaces (#515)
1 parent e26f390 commit f64ab2e

File tree

3 files changed

+14
-22
lines changed

3 files changed

+14
-22
lines changed

benchmarking/benchmark.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,13 @@
66
import matplotlib as mpl
77
import matplotlib.pyplot as plt
88
import numpy as np
9-
import pandas as pd
109
import skfp.fingerprints as fps
1110
from joblib import cpu_count
12-
from ogb.graphproppred import GraphPropPredDataset
11+
from skfp.datasets.moleculenet import load_hiv
1312
from skfp.preprocessing import ConformerGenerator, MolFromSmilesTransformer
1413

1514
mpl.rcParams.update({"font.size": 18})
1615

17-
DATASET_NAME = "ogbg-molhiv"
18-
1916
# N_SPLITS - number of parts in which the dataset will be divided.
2017
# the test is performed first on 1 of them, then 2, ... then N_SPLITS
2118
# testing different sizes of input data
@@ -160,17 +157,13 @@ def make_combined_plot(
160157
if not os.path.exists(SCORE_DIR):
161158
os.makedirs(SCORE_DIR)
162159

163-
GraphPropPredDataset(name=DATASET_NAME, root=os.path.join("..", "dataset"))
164-
dataset_path = os.path.join(
165-
"..", "dataset", "_".join(DATASET_NAME.split("-")), "mapping", "mol.csv.gz"
166-
)
167-
dataset = pd.read_csv(dataset_path)
160+
dataset = load_hiv()
168161

169162
if os.path.exists("mols_with_conformers.npy"):
170163
X = np.load("mols_with_conformers.npy", allow_pickle=True)
171164
else:
172165
X = dataset["smiles"][:10000]
173-
X = MolFromSmilesTransformer().transform(X)
166+
X = MolFromSmilesTransformer(valid_only=True).transform(X)
174167
X = ConformerGenerator(n_jobs=-1, errors="filter").transform(X)
175168
X = np.array(X)
176169
np.save("mols_with_conformers.npy", X, allow_pickle=True)

benchmarking/fp_tuning.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import numpy as np
22
import skfp.fingerprints as fps
3-
from ogb.graphproppred import GraphPropPredDataset
43
from rdkit.Chem import Mol
54
from skfp.bases import BaseFingerprintTransformer
6-
from skfp.datasets.moleculenet import load_moleculenet_benchmark
5+
from skfp.datasets.moleculenet import load_moleculenet_benchmark, load_ogb_splits
76
from skfp.preprocessing import MolFromSmilesTransformer
87
from skfp.utils import no_rdkit_logs
98
from sklearn.ensemble import RandomForestClassifier
@@ -42,7 +41,7 @@ def fp_name_to_fp(fp_name: str) -> tuple[BaseFingerprintTransformer, dict]:
4241
fingerprint = fps.EStateFingerprint(n_jobs=-1)
4342
fp_params_grid = {"variant": ["sum", "bit", "count"]}
4443
elif fp_name == "FCFP":
45-
fingerprint = fps.ECFPFingerprint(use_fcfp=True, n_jobs=-1)
44+
fingerprint = fps.ECFPFingerprint(use_pharmacophoric_invariants=True, n_jobs=-1)
4645
fp_params_grid = {
4746
"fp_size": [1024, 2048, 4096],
4847
"radius": [2, 3],
@@ -78,7 +77,7 @@ def fp_name_to_fp(fp_name: str) -> tuple[BaseFingerprintTransformer, dict]:
7877
fp_params_grid = {
7978
"fp_size": [512, 1024, 2048],
8079
"radius": [2, 3],
81-
"variant": ["bit", "count"],
80+
"count": [False, True],
8281
}
8382
elif fp_name == "Pattern":
8483
fingerprint = fps.PatternFingerprint()
@@ -157,13 +156,9 @@ def train_and_tune_fp_classifier(
157156
print("DATASET", dataset_name)
158157
X = np.array(X)
159158

160-
dataset = GraphPropPredDataset(
161-
name=f"ogbg-mol{dataset_name.lower()}", root=".tmp"
162-
)
163-
split_idx = dataset.get_idx_split()
159+
train_idxs, valid_idxs, test_idxs = load_ogb_splits(dataset_name)
164160

165-
train_idxs = list(split_idx["train"]) + list(split_idx["valid"])
166-
test_idxs = list(split_idx["test"])
161+
train_idxs = list(train_idxs) + list(valid_idxs)
167162

168163
smiles_train = X[train_idxs]
169164
smiles_test = X[test_idxs]
@@ -206,6 +201,7 @@ def train_and_tune_fp_classifier(
206201
fp=fp,
207202
fp_params_grid=fp_params_grid,
208203
)
204+
209205
print(
210206
f"AUROC default {auroc_default:.1%}, tuned {auroc_tuned:.1%}, diff: {diff:.1%}"
211207
)

skfp/bases/base_fp_transformer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -271,8 +271,11 @@ def _hash_fingerprint_bits(
271271
)
272272

273273
shape = (len(X), fp_size)
274-
dtype = np.uint32 if count else np.uint8
275-
arr = dok_array(shape, dtype=dtype) if sparse else np.zeros(shape, dtype=dtype)
274+
arr = (
275+
dok_array(shape, dtype=np.uint32)
276+
if sparse
277+
else np.zeros(shape, dtype=np.uint32)
278+
)
276279

277280
if isinstance(X[0], SparseBitVect):
278281
for idx, x in enumerate(X):

0 commit comments

Comments
 (0)