Skip to content

Commit 3ab691f

Browse files
committed
add fingerprint dataset, logistic regression model
1 parent 016b5ea commit 3ab691f

File tree

4 files changed

+110
-1
lines changed

4 files changed

+110
-1
lines changed

chebai/models/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,8 @@ def _get_prediction_and_labels(
106106
Returns:
107107
Tuple[torch.Tensor, torch.Tensor]: Predictions and labels.
108108
"""
109-
return output, labels
109+
# cast labels to int
110+
return output, labels.to(torch.int) if labels is not None else labels
110111

111112
def _process_labels_in_batch(self, batch: XYData) -> torch.Tensor:
112113
"""

chebai/models/classic_ml.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
from typing import Any, Dict
2+
3+
import torch
4+
import tqdm
5+
from sklearn.exceptions import NotFittedError
6+
from sklearn.linear_model import LogisticRegression as SklearnLogisticRegression
7+
8+
from chebai.models.base import ChebaiBaseNet
9+
10+
11+
class LogisticRegression(ChebaiBaseNet):
12+
"""
13+
Logistic Regression model using scikit-learn, wrapped to fit the ChebaiBaseNet interface.
14+
"""
15+
16+
def __init__(self, out_dim: int, input_dim: int, **kwargs):
17+
super().__init__(out_dim=out_dim, input_dim=input_dim, **kwargs)
18+
self.models = [SklearnLogisticRegression(solver="liblinear") for _ in range(5)]
19+
20+
def forward(self, x: Dict[str, Any], **kwargs) -> torch.Tensor:
21+
print(
22+
f"forward called with x[features].shape {x['features'].shape}, self.training {self.training}"
23+
)
24+
if self.training:
25+
self.fit_sklearn(x["features"], x["labels"])
26+
try:
27+
preds = [
28+
torch.from_numpy(model.predict(x["features"]))
29+
.to(
30+
x["features"].device
31+
if isinstance(x["features"], torch.Tensor)
32+
else "cpu"
33+
)
34+
.float()
35+
for model in self.models
36+
]
37+
except NotFittedError:
38+
# Not fitted yet, return zeros
39+
print(
40+
f"returning default 0s with shape {(x['features'].shape[0], self.out_dim)}"
41+
)
42+
return torch.zeros(
43+
(x["features"].shape[0], self.out_dim),
44+
device=(
45+
x["features"].device
46+
if isinstance(x["features"], torch.Tensor)
47+
else "cpu"
48+
),
49+
)
50+
preds = torch.stack(preds, dim=1)
51+
print(f"preds shape {preds.shape}")
52+
return preds
53+
54+
def fit_sklearn(self, X, y):
55+
"""
56+
Fit the underlying sklearn model. X and y should be numpy arrays.
57+
"""
58+
for i, model in tqdm.tqdm(enumerate(self.models), desc="Fitting models"):
59+
model.fit(X, y[:, i])
60+
61+
def configure_optimizers(self, **kwargs):
62+
pass

chebai/preprocessing/datasets/chebi.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -809,6 +809,22 @@ class ChEBIOver50Partial(ChEBIOverXPartial, ChEBIOver50):
809809
pass
810810

811811

812+
class ChEBIOverXFingerprints(ChEBIOverX):
813+
"""A class that uses Fingerprints for the processed data (used for fixed-length ML models)."""
814+
815+
READER = dr.FingerprintReader
816+
817+
818+
class ChEBIOver100Fingerprints(ChEBIOverXFingerprints, ChEBIOver100):
819+
"""
820+
A class for extracting data from the ChEBI dataset with Fingerprints reader and a threshold of 100.
821+
822+
Inherits from ChEBIOverXFingerprints and ChEBIOver100.
823+
"""
824+
825+
pass
826+
827+
812828
class JCIExtendedBPEData(JCIExtendedBase):
813829
READER = dr.ChemBPEReader
814830

chebai/preprocessing/reader.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -372,3 +372,33 @@ def name(cls) -> str:
372372
def _read_data(self, raw_data: str) -> List[int]:
373373
"""Convert characters in raw data to their ordinal values."""
374374
return [ord(s) for s in raw_data]
375+
376+
377+
class FingerprintReader(DataReader):
378+
"""
379+
Data reader for chemical data using RDKit fingerprints.
380+
381+
Args:
382+
collator_kwargs: Optional dictionary of keyword arguments for the collator.
383+
kwargs: Additional keyword arguments.
384+
"""
385+
386+
COLLATOR = DefaultCollator
387+
388+
def __init__(self, fingerprint_size=1024, *args, **kwargs):
389+
super().__init__(*args, **kwargs)
390+
self.fingerprint_size = fingerprint_size
391+
392+
@classmethod
393+
def name(cls) -> str:
394+
"""Returns the name of the data reader."""
395+
return "rdkit_fingerprint"
396+
397+
def _read_data(self, raw_data: str) -> List[int]:
398+
"""Generate RDKit fingerprint from raw SMILES data."""
399+
mol = Chem.MolFromSmiles(raw_data.strip())
400+
if mol is None:
401+
raise ValueError(f"Invalid SMILES: {raw_data}")
402+
return list(
403+
Chem.RDKFingerprint(mol, fpSize=self.fingerprint_size).ToBitString()
404+
)

0 commit comments

Comments
 (0)