
Commit 6f2c3ab

Add ensemble for electra models, majority voting
1 parent 7ea61d0

9 files changed, +367 -0 lines changed

README.md

Lines changed: 84 additions & 0 deletions
@@ -1,2 +1,86 @@
# python-chebifier

An AI ensemble model for predicting chemical classes.

## Installation

```bash
# Clone the repository
git clone https://github.com/yourusername/python-chebifier.git
cd python-chebifier

# Install the package
pip install -e .
```

## Usage

### Command Line Interface

The package provides a command-line interface (CLI) for making predictions using an ensemble model.

```bash
# Get help
python -m chebifier.cli --help

# Make predictions using a configuration file
python -m chebifier.cli predict example_config.yml --smiles "CC(=O)OC1=CC=CC=C1C(=O)O" "C1=CC=C(C=C1)C(=O)O"

# Make predictions using SMILES from a file
python -m chebifier.cli predict example_config.yml --smiles-file smiles.txt
```
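
The file passed to `--smiles-file` is plain text with one SMILES string per line (as stated in the option's help text). A hypothetical `smiles.txt` could look like this:

```text
CC(=O)OC1=CC=CC=C1C(=O)O
C1=CC=C(C=C1)C(=O)O
CCO
```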

### Configuration File

The CLI requires a YAML configuration file that defines the ensemble model. Here's an example:

```yaml
# Example configuration file for Chebifier ensemble model

# Each key in the top-level dictionary is a model name
model1:
  # Required: type of model (must be one of the keys in MODEL_TYPES)
  type: electra
  # Required: name of the model
  model_name: electra_model1
  # Required: path to the checkpoint file
  ckpt_path: /path/to/checkpoint1.ckpt
  # Required: path to the target labels file
  target_labels_path: /path/to/target_labels1.txt
  # Optional: batch size for predictions (defaults to 1 if omitted)
  batch_size: 32

model2:
  type: electra
  model_name: electra_model2
  ckpt_path: /path/to/checkpoint2.ckpt
  target_labels_path: /path/to/target_labels2.txt
  batch_size: 64
```
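
The target labels file referenced by `target_labels_path` is read as one class label per line (see `NNPredictor.__init__` further down in this commit). A hypothetical `target_labels1.txt` might therefore contain:

```text
CHEBI:15377
CHEBI:16236
CHEBI:25558
```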

### Python API

You can also use the package programmatically:

```python
from chebifier.ensemble.base_ensemble import BaseEnsemble
import yaml

# Load configuration from YAML file
with open('configs/example_config.yml', 'r') as f:
    config = yaml.safe_load(f)

# Instantiate ensemble model
ensemble = BaseEnsemble(config)

# Make predictions
smiles_list = ["CC(=O)OC1=CC=CC=C1C(=O)O", "C1=CC=C(C=C1)C(=O)O"]
predictions = ensemble.predict_smiles_list(smiles_list)

# Print results
for smile, prediction in zip(smiles_list, predictions):
    print(f"SMILES: {smile}")
    if prediction:
        print(f"Predicted classes: {prediction}")
    else:
        print("No predictions")
```
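
`predict_smiles_list` returns one entry per input SMILES: the list of class IDs accepted by the ensemble's majority vote (see `BaseEnsemble.aggregate_predictions` below). With entirely made-up class IDs, a run of the snippet above might print:

```text
SMILES: CC(=O)OC1=CC=CC=C1C(=O)O
Predicted classes: ['CHEBI:25558', 'CHEBI:33575']
SMILES: C1=CC=C(C=C1)C(=O)O
No predictions
```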

chebifier/__init__.py

Whitespace-only changes.

chebifier/cli.py

Lines changed: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
```python
import click
import yaml

from chebifier.ensemble.base_ensemble import BaseEnsemble


@click.group()
def cli():
    """Command line interface for Chebifier."""
    pass


@cli.command()
@click.argument('config_file', type=click.Path(exists=True))
@click.option('--smiles', '-s', multiple=True, help='SMILES strings to predict')
@click.option('--smiles-file', '-f', type=click.Path(exists=True), help='File containing SMILES strings (one per line)')
@click.option('--output', '-o', type=click.Path(), help='Output file to save predictions (optional)')
def predict(config_file, smiles, smiles_file, output):
    """Predict ChEBI classes for SMILES strings using an ensemble model.

    CONFIG_FILE is the path to a YAML configuration file for the ensemble model.
    """
    # Load configuration from YAML file
    with open(config_file, 'r') as f:
        config = yaml.safe_load(f)

    # Instantiate ensemble model
    ensemble = BaseEnsemble(config)

    # Collect SMILES strings from arguments and/or file
    smiles_list = list(smiles)
    if smiles_file:
        with open(smiles_file, 'r') as f:
            smiles_list.extend([line.strip() for line in f if line.strip()])

    if not smiles_list:
        click.echo("No SMILES strings provided. Use --smiles or --smiles-file options.")
        return

    # Make predictions
    predictions = ensemble.predict_smiles_list(smiles_list)

    if output:
        # Save predictions as JSON, keyed by the input SMILES
        import json
        with open(output, 'w') as f:
            json.dump({smiles: pred for smiles, pred in zip(smiles_list, predictions)}, f, indent=2)
    else:
        # Print results
        for smiles, prediction in zip(smiles_list, predictions):
            click.echo(f"Result for: {smiles}")
            if prediction:
                click.echo(f"  Predicted classes: {', '.join(map(str, prediction))}")
            else:
                click.echo("  No predictions")


if __name__ == '__main__':
    cli()
```
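
When `--output` is given, the CLI dumps a JSON object keyed by the input SMILES, with each value being the list of predicted class IDs (an empty list if nothing was accepted by the ensemble). With made-up class IDs, the output file would look roughly like:

```json
{
  "CC(=O)OC1=CC=CC=C1C(=O)O": ["CHEBI:25558", "CHEBI:33575"],
  "C1=CC=C(C=C1)C(=O)O": []
}
```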

chebifier/ensemble/__init__.py

Whitespace-only changes.
chebifier/ensemble/base_ensemble.py

Lines changed: 109 additions & 0 deletions
@@ -0,0 +1,109 @@
```python
from abc import ABC
import torch
import tqdm
from rdkit import Chem

from chebifier.prediction_models.base_predictor import BasePredictor
from chebifier.prediction_models.electra_predictor import ElectraPredictor

MODEL_TYPES = {
    "electra": ElectraPredictor,
    # todo add other model types here
}


class BaseEnsemble(ABC):

    def __init__(self, model_configs: dict):
        self.models = []
        for model_name, model_config in model_configs.items():
            model_cls = MODEL_TYPES[model_config["type"]]
            model_instance = model_cls(**model_config)
            assert isinstance(model_instance, BasePredictor)
            self.models.append(model_instance)

    def gather_predictions(self, smiles_list):
        """Collect predictions from all models for a list of SMILES strings.

        :param smiles_list: list of SMILES strings to predict
        :return:
            ordered_predictions: torch.Tensor of shape (num_smiles, num_classes, num_models)
            predicted_classes: sorted list of ChEBI IDs predicted by the models
        """
        model_predictions = []
        predicted_classes = set()
        for model in self.models:
            model_predictions.append(model.predict_smiles_list(smiles_list))
            for predicted_smiles in model_predictions[-1]:
                if predicted_smiles is not None:
                    for cls in predicted_smiles:
                        predicted_classes.add(cls)
        print("Sorting predictions...")
        predicted_classes = sorted(predicted_classes)
        # Initialise with NaN, meaning "model made no prediction for this class"
        ordered_predictions = torch.zeros(len(smiles_list), len(predicted_classes), len(self.models)) * torch.nan
        for i, model_prediction in enumerate(model_predictions):
            for j, predicted_smiles in tqdm.tqdm(enumerate(model_prediction),
                                                 total=len(model_prediction),
                                                 desc=f"Sorting predictions for {self.models[i].model_name}"):
                if predicted_smiles is not None:
                    for cls in predicted_smiles:
                        ordered_predictions[j, predicted_classes.index(cls), i] = predicted_smiles[cls]
        return ordered_predictions, predicted_classes

    def aggregate_predictions(self, predictions, predicted_classes, **kwargs):
        """Aggregates predictions from multiple models using majority voting.

        :param predictions: torch.Tensor of shape (num_smiles, num_classes, num_models)
        :param predicted_classes: list of ChEBI IDs predicted by the models
        :param kwargs: Additional arguments
        :return: list of lists, where each inner list contains the class IDs that received
            positive predictions from the majority of models for a given SMILES
        """
        num_smiles, num_classes, num_models = predictions.shape
        result = []

        for i in tqdm.tqdm(range(num_smiles), total=num_smiles, desc="Aggregating predictions"):
            smiles_result = []
            for j in range(num_classes):
                # Get predictions for this SMILES and class across all models
                class_predictions = predictions[i, j, :]

                # Count models that made a prediction (not NaN)
                valid_predictions = ~torch.isnan(class_predictions)
                num_valid_predictions = valid_predictions.sum().item()

                # If no valid predictions, skip this class
                if num_valid_predictions == 0:
                    continue

                # Count positive predictions (assuming positive is > 0)
                positive_predictions = class_predictions > 0
                num_positive = (positive_predictions & valid_predictions).sum().item()

                # If the majority of models that made a prediction are positive, add this class
                if num_positive > num_valid_predictions / 2:
                    smiles_result.append(predicted_classes[j])

            result.append(smiles_result)

        return result

    def normalize_smiles_list(self, smiles_list):
        new = []
        print("Normalizing SMILES strings...")
        for smiles in tqdm.tqdm(smiles_list):
            try:
                mol = Chem.MolFromSmiles(smiles)
                canonical_smiles = Chem.MolToSmiles(mol)
            except Exception as e:
                print(f"Failed to parse SMILES '{smiles}': {e}")
                canonical_smiles = None
            new.append(canonical_smiles)
        return new

    def predict_smiles_list(self, smiles_list) -> list:
        # smiles_list = self.normalize_smiles_list(smiles_list)
        ordered_predictions, predicted_classes = self.gather_predictions(smiles_list)
        aggregated_predictions = self.aggregate_predictions(ordered_predictions, predicted_classes)
        return aggregated_predictions
```
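
As a sanity check on the voting rule in `aggregate_predictions`, here is a minimal standalone sketch with made-up scores and hypothetical class IDs. The tensor has shape (2 SMILES, 2 classes, 3 models); NaN marks a model that produced no prediction for a class, and scores above 0 count as positive votes, exactly as in the method above:

```python
import torch

# Made-up scores, shape (num_smiles=2, num_classes=2, num_models=3)
scores = torch.tensor([
    [[1.2, -0.3, 0.8], [float("nan"), -1.0, -0.5]],
    [[-0.2, 0.4, float("nan")], [2.1, 0.9, 1.5]],
])
classes = ["CHEBI:15377", "CHEBI:16236"]  # hypothetical class IDs

for i in range(scores.shape[0]):
    accepted = []
    for j, cls in enumerate(classes):
        votes = scores[i, j, :]
        valid = ~torch.isnan(votes)       # models that actually voted
        positive = (votes > 0) & valid    # positive votes among them
        # Keep the class only if a strict majority of the voting models is positive
        if valid.sum() > 0 and positive.sum() > valid.sum() / 2:
            accepted.append(cls)
    print(f"SMILES {i}: {accepted}")
# SMILES 0: ['CHEBI:15377']  (class 0: 2 of 3 positive; class 1: 0 of 2)
# SMILES 1: ['CHEBI:16236']  (class 0: 1 of 2 is no strict majority; class 1: 3 of 3)
```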

chebifier/prediction_models/__init__.py

Whitespace-only changes.
chebifier/prediction_models/base_predictor.py

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
```python
from abc import ABC


class BasePredictor(ABC):

    def __init__(self, model_name: str, **kwargs):
        self.model_name = model_name

    def predict_smiles_list(self, smiles_list: list[str]) -> list:
        raise NotImplementedError
```
chebifier/prediction_models/electra_predictor.py

Lines changed: 22 additions & 0 deletions
@@ -0,0 +1,22 @@
```python
from chebifier.prediction_models.nn_predictor import NNPredictor
from chebai.models.electra import Electra
from chebai.preprocessing.reader import ChemDataReader


class ElectraPredictor(NNPredictor):

    def __init__(self, model_name: str, ckpt_path: str, **kwargs):
        super().__init__(model_name, ckpt_path, reader_cls=ChemDataReader, **kwargs)
        print(f"Initialised Electra model {self.model_name} (device: {self.device})")

    def init_model(self, ckpt_path: str, **kwargs) -> Electra:
        model = Electra.load_from_checkpoint(
            ckpt_path,
            map_location=self.device,
            criterion=None, strict=False,
            metrics=dict(train=dict(), test=dict(), validation=dict()), pretrained_checkpoint=None
        )
        model.eval()
        return model
```
chebifier/prediction_models/nn_predictor.py

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
```python
import tqdm

from chebifier.prediction_models.base_predictor import BasePredictor
from rdkit import Chem
import numpy as np
import torch


class NNPredictor(BasePredictor):

    def __init__(self, model_name: str, ckpt_path: str, reader_cls, target_labels_path: str, **kwargs):
        super().__init__(model_name, **kwargs)
        self.reader_cls = reader_cls

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.init_model(ckpt_path=ckpt_path)
        self.target_labels = [line.strip() for line in open(target_labels_path, encoding="utf-8")]
        self.batch_size = kwargs.get("batch_size", 1)

    def init_model(self, ckpt_path: str, **kwargs):
        raise NotImplementedError("Model initialization must be implemented in subclasses.")

    def calculate_results(self, batch):
        collator = self.reader_cls.COLLATOR()
        dat = self.model._process_batch(collator(batch).to(self.device), 0)
        return self.model(dat, **dat["model_kwargs"])

    def batchify(self, batch):
        cache = []
        for r in batch:
            cache.append(r)
            if len(cache) >= self.batch_size:
                yield cache
                cache = []
        if cache:
            yield cache

    def read_smiles(self, smiles):
        reader = self.reader_cls()
        d = reader.to_data(dict(features=smiles, labels=None))
        return d

    def predict_smiles_list(self, smiles_list) -> list:
        """Returns a list with the length of smiles_list; each element is either None (= failure)
        or a dictionary of classes and predicted values."""
        token_dicts = []
        could_not_parse = []
        index_map = dict()
        for i, smiles in enumerate(smiles_list):
            try:
                # Try to parse the SMILES string
                if not smiles:
                    raise ValueError()
                d = self.read_smiles(smiles)
                # This is just for sanity checks
                rdmol = Chem.MolFromSmiles(smiles, sanitize=False)
            except Exception as e:
                # Note if it fails
                could_not_parse.append(i)
                print(f"Failed to parse {smiles} due to {e}")
            else:
                if rdmol is None:
                    could_not_parse.append(i)
                else:
                    index_map[i] = len(token_dicts)
                    token_dicts.append(d)
        results = []
        if token_dicts:
            for batch in tqdm.tqdm(self.batchify(token_dicts), desc=f"{self.model_name}",
                                   total=len(token_dicts) // self.batch_size):
                result = self.calculate_results(batch)
                if isinstance(result, dict) and "logits" in result:
                    result = result["logits"]
                results += result.cpu().detach().tolist()
            results = np.stack(results, axis=0)
            preds = [{self.target_labels[j]: p for j, p in enumerate(results[index_map[i]])}
                     if i not in could_not_parse else None for i in range(len(smiles_list))]
            return preds
        else:
            return [None for _ in smiles_list]
```
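
For illustration, the value returned by `predict_smiles_list` pairs each input with either a per-label score dictionary or None. With hypothetical target labels and made-up scores, two inputs (one valid, one unparsable) would come back roughly as:

```python
# Hypothetical return value for ["CC(=O)O", "not-a-smiles"] (labels and scores made up)
preds = [
    {"CHEBI:25558": 3.1, "CHEBI:33575": -1.7},  # one predicted value per target label
    None,                                       # SMILES that could not be parsed
]
```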
