Skip to content

Commit 9c3beea

Browse files
committed
pre-commit -run -a
1 parent 997120e commit 9c3beea

File tree

7 files changed

+169
-77
lines changed

7 files changed

+169
-77
lines changed

chebifier/cli.py

Lines changed: 43 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -1,48 +1,66 @@
1-
2-
3-
41
import click
52
import yaml
6-
import sys
3+
74
from chebifier.ensemble.base_ensemble import BaseEnsemble
8-
from chebifier.ensemble.weighted_majority_ensemble import WMVwithPPVNPVEnsemble, WMVwithF1Ensemble
5+
from chebifier.ensemble.weighted_majority_ensemble import (
6+
WMVwithF1Ensemble,
7+
WMVwithPPVNPVEnsemble,
8+
)
99

1010

1111
@click.group()
1212
def cli():
1313
"""Command line interface for Chebifier."""
1414
pass
1515

16+
1617
ENSEMBLES = {
1718
"mv": BaseEnsemble,
1819
"wmv-ppvnpv": WMVwithPPVNPVEnsemble,
19-
"wmv-f1": WMVwithF1Ensemble
20+
"wmv-f1": WMVwithF1Ensemble,
2021
}
2122

23+
2224
@cli.command()
23-
@click.argument('config_file', type=click.Path(exists=True))
24-
@click.option('--smiles', '-s', multiple=True, help='SMILES strings to predict')
25-
@click.option('--smiles-file', '-f', type=click.Path(exists=True), help='File containing SMILES strings (one per line)')
26-
@click.option('--output', '-o', type=click.Path(), help='Output file to save predictions (optional)')
27-
@click.option('--ensemble-type', '-e', type=click.Choice(ENSEMBLES.keys()), default='mv', help='Type of ensemble to use (default: Majority Voting)')
25+
@click.argument("config_file", type=click.Path(exists=True))
26+
@click.option("--smiles", "-s", multiple=True, help="SMILES strings to predict")
27+
@click.option(
28+
"--smiles-file",
29+
"-f",
30+
type=click.Path(exists=True),
31+
help="File containing SMILES strings (one per line)",
32+
)
33+
@click.option(
34+
"--output",
35+
"-o",
36+
type=click.Path(),
37+
help="Output file to save predictions (optional)",
38+
)
39+
@click.option(
40+
"--ensemble-type",
41+
"-e",
42+
type=click.Choice(ENSEMBLES.keys()),
43+
default="mv",
44+
help="Type of ensemble to use (default: Majority Voting)",
45+
)
2846
def predict(config_file, smiles, smiles_file, output, ensemble_type):
2947
"""Predict ChEBI classes for SMILES strings using an ensemble model.
30-
48+
3149
CONFIG_FILE is the path to a YAML configuration file for the ensemble model.
3250
"""
3351
# Load configuration from YAML file
34-
with open(config_file, 'r') as f:
52+
with open(config_file, "r") as f:
3553
config = yaml.safe_load(f)
36-
54+
3755
# Instantiate ensemble model
3856
ensemble = ENSEMBLES[ensemble_type](config)
39-
57+
4058
# Collect SMILES strings from arguments and/or file
4159
smiles_list = list(smiles)
4260
if smiles_file:
43-
with open(smiles_file, 'r') as f:
61+
with open(smiles_file, "r") as f:
4462
smiles_list.extend([line.strip() for line in f if line.strip()])
45-
63+
4664
if not smiles_list:
4765
click.echo("No SMILES strings provided. Use --smiles or --smiles-file options.")
4866
return
@@ -53,8 +71,13 @@ def predict(config_file, smiles, smiles_file, output, ensemble_type):
5371
if output:
5472
# save as json
5573
import json
56-
with open(output, 'w') as f:
57-
json.dump({smiles: pred for smiles, pred in zip(smiles_list, predictions)}, f, indent=2)
74+
75+
with open(output, "w") as f:
76+
json.dump(
77+
{smiles: pred for smiles, pred in zip(smiles_list, predictions)},
78+
f,
79+
indent=2,
80+
)
5881

5982
else:
6083
# Print results
@@ -66,5 +89,5 @@ def predict(config_file, smiles, smiles_file, output, ensemble_type):
6689
click.echo(" No predictions")
6790

6891

69-
if __name__ == '__main__':
92+
if __name__ == "__main__":
7093
cli()

chebifier/ensemble/weighted_majority_ensemble.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,7 @@
33
from chebifier.ensemble.base_ensemble import BaseEnsemble
44

55

6-
76
class WMVwithPPVNPVEnsemble(BaseEnsemble):
8-
97
def calculate_classwise_weights(self, predicted_classes):
108
"""
119
Given the positions of predicted classes in the predictions tensor, assign weights to each class. The
@@ -23,15 +21,18 @@ def calculate_classwise_weights(self, predicted_classes):
2321
positive_weights[predicted_classes[cls], j] *= weights["PPV"]
2422
negative_weights[predicted_classes[cls], j] *= weights["NPV"]
2523

26-
print(f"Calculated model weightings. The averages for positive / negative weights are:")
24+
print(
25+
"Calculated model weightings. The averages for positive / negative weights are:"
26+
)
2727
for i, model in enumerate(self.models):
28-
print(f"{model.model_name}: {positive_weights[:, i].mean().item():.3f} / {negative_weights[:, i].mean().item():.3f}")
28+
print(
29+
f"{model.model_name}: {positive_weights[:, i].mean().item():.3f} / {negative_weights[:, i].mean().item():.3f}"
30+
)
2931

3032
return positive_weights, negative_weights
3133

3234

3335
class WMVwithF1Ensemble(BaseEnsemble):
34-
3536
def calculate_classwise_weights(self, predicted_classes):
3637
"""
3738
Given the positions of predicted classes in the predictions tensor, assign weights to each class. The
@@ -45,11 +46,15 @@ def calculate_classwise_weights(self, predicted_classes):
4546
continue
4647
for cls, weights in model.classwise_weights.items():
4748
if (2 * weights["TP"] + weights["FP"] + weights["FN"]) > 0:
48-
f1 = 2 * weights["TP"] / (2 * weights["TP"] + weights["FP"] + weights["FN"])
49+
f1 = (
50+
2
51+
* weights["TP"]
52+
/ (2 * weights["TP"] + weights["FP"] + weights["FN"])
53+
)
4954
weights_by_cls[predicted_classes[cls], j] *= f1
5055

51-
print(f"Calculated model weightings. The average weights are:")
56+
print("Calculated model weightings. The average weights are:")
5257
for i, model in enumerate(self.models):
5358
print(f"{model.model_name}: {weights_by_cls[:, i].mean().item():.3f}")
5459

55-
return weights_by_cls, weights_by_cls
60+
return weights_by_cls, weights_by_cls

chebifier/prediction_models/base_predictor.py

Lines changed: 13 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,16 +1,24 @@
1-
from abc import ABC
21
import json
2+
from abc import ABC
3+
34

45
class BasePredictor(ABC):
56

6-
def __init__(self, model_name: str, model_weight: int = 1, classwise_weights_path: str = None, **kwargs):
7+
def __init__(
8+
self,
9+
model_name: str,
10+
model_weight: int = 1,
11+
classwise_weights_path: str = None,
12+
**kwargs
13+
):
714
self.model_name = model_name
815
self.model_weight = model_weight
916
if classwise_weights_path is not None:
10-
self.classwise_weights = json.load(open(classwise_weights_path, encoding="utf-8"))
17+
self.classwise_weights = json.load(
18+
open(classwise_weights_path, encoding="utf-8")
19+
)
1120
else:
1221
self.classwise_weights = None
1322

14-
1523
def predict_smiles_list(self, smiles_list: list[str]) -> dict:
16-
raise NotImplementedError
24+
raise NotImplementedError
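
(file name missing from this page extract; this section is the module that defines ChemLogPredictor, presumably under chebifier/prediction_models/)

Lines changed: 22 additions & 11 deletions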
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,22 @@
11
import tqdm
2+
from chemlog.cli import CLASSIFIERS, _smiles_to_mol, strategy_call
23

34
from chebifier.prediction_models.base_predictor import BasePredictor
4-
from chemlog.alg_classification.charge_classifier import AlgChargeClassifier
5-
from chemlog.alg_classification.peptide_size_classifier import AlgPeptideSizeClassifier
6-
from chemlog.alg_classification.proteinogenics_classifier import AlgProteinogenicsClassifier
7-
from chemlog.alg_classification.substructure_classifier import AlgSubstructureClassifier
8-
from chemlog.cli import strategy_call, _smiles_to_mol, CLASSIFIERS
95

10-
class ChemLogPredictor(BasePredictor):
116

7+
class ChemLogPredictor(BasePredictor):
128
def __init__(self, model_name: str, **kwargs):
139
super().__init__(model_name, **kwargs)
1410
self.strategy = "algo"
1511
self.classifier_instances = {
1612
k: v() for k, v in CLASSIFIERS[self.strategy].items()
1713
}
18-
self.peptide_labels = ["15841", "16670", "24866", "25676", "25696", "25697", "27369", "46761", "47923",
19-
"48030", "48545", "60194", "60334", "60466", "64372", "65061", "90799", "155837"]
20-
14+
# fmt: off
15+
self.peptide_labels = [
16+
"15841", "16670", "24866", "25676", "25696", "25697", "27369", "46761", "47923",
17+
"48030", "48545", "60194", "60334", "60466", "64372", "65061", "90799", "155837"
18+
]
19+
# fmt: on
2120
print(f"Initialised ChemLog model {self.model_name}")
2221

2322
def predict_smiles_list(self, smiles_list: list[str]) -> list:
@@ -27,9 +26,21 @@ def predict_smiles_list(self, smiles_list: list[str]) -> list:
2726
if mol is None:
2827
results.append(None)
2928
else:
30-
results.append({label: 1 if label in strategy_call(self.strategy, self.classifier_instances, mol)["chebi_classes"] else 0 for label in self.peptide_labels})
29+
results.append(
30+
{
31+
label: (
32+
1
33+
if label
34+
in strategy_call(
35+
self.strategy, self.classifier_instances, mol
36+
)["chebi_classes"]
37+
else 0
38+
)
39+
for label in self.peptide_labels
40+
}
41+
)
3142

3243
for classifier in self.classifier_instances.values():
3344
classifier.on_finish()
3445

35-
return results
46+
return results

chebifier/prediction_models/electra_predictor.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,8 @@
1-
from chebifier.prediction_models.nn_predictor import NNPredictor
21
from chebai.models.electra import Electra
32
from chebai.preprocessing.reader import ChemDataReader
43

4+
from chebifier.prediction_models.nn_predictor import NNPredictor
5+
56

67
class ElectraPredictor(NNPredictor):
78

@@ -13,10 +14,10 @@ def init_model(self, ckpt_path: str, **kwargs) -> Electra:
1314
model = Electra.load_from_checkpoint(
1415
ckpt_path,
1516
map_location=self.device,
16-
criterion=None, strict=False,
17-
metrics=dict(train=dict(), test=dict(), validation=dict()), pretrained_checkpoint=None
17+
criterion=None,
18+
strict=False,
19+
metrics=dict(train=dict(), test=dict(), validation=dict()),
20+
pretrained_checkpoint=None,
1821
)
1922
model.eval()
2023
return model
21-
22-

chebifier/prediction_models/gnn_predictor.py

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
1-
from chebifier.prediction_models.nn_predictor import NNPredictor
21
import chebai_graph.preprocessing.properties as p
32
import torch
43
from chebai_graph.models.graph import ResGatedGraphConvNetGraphPred
5-
from chebai_graph.preprocessing.reader import GraphPropertyReader
64
from chebai_graph.preprocessing.property_encoder import IndexEncoder, OneHotEncoder
5+
from chebai_graph.preprocessing.reader import GraphPropertyReader
76
from torch_geometric.data.data import Data as GeomData
87

8+
from chebifier.prediction_models.nn_predictor import NNPredictor
9+
910

1011
class ResGatedPredictor(NNPredictor):
1112

1213
def __init__(self, model_name: str, ckpt_path: str, molecular_properties, **kwargs):
13-
super().__init__(model_name, ckpt_path, reader_cls=GraphPropertyReader, **kwargs)
14+
super().__init__(
15+
model_name, ckpt_path, reader_cls=GraphPropertyReader, **kwargs
16+
)
1417
# molecular_properties is a list of class paths
1518
if molecular_properties is not None:
1619
properties = [self.load_class(prop)() for prop in molecular_properties]
@@ -32,11 +35,23 @@ def load_class(self, class_path: str):
3235

3336
def init_model(self, ckpt_path: str, **kwargs) -> ResGatedGraphConvNetGraphPred:
3437
model = ResGatedGraphConvNetGraphPred.load_from_checkpoint(
35-
ckpt_path, map_location=torch.device(self.device), criterion=None, strict=False,
36-
metrics=dict(train=dict(), test=dict(), validation=dict()), pretrained_checkpoint=None,
37-
config={"in_length": 256, "hidden_length": 512, "dropout_rate": 0.1, "n_conv_layers": 3,
38-
"n_linear_layers": 3, "n_atom_properties": 158, "n_bond_properties": 7,
39-
"n_molecule_properties": 200})
38+
ckpt_path,
39+
map_location=torch.device(self.device),
40+
criterion=None,
41+
strict=False,
42+
metrics=dict(train=dict(), test=dict(), validation=dict()),
43+
pretrained_checkpoint=None,
44+
config={
45+
"in_length": 256,
46+
"hidden_length": 512,
47+
"dropout_rate": 0.1,
48+
"n_conv_layers": 3,
49+
"n_linear_layers": 3,
50+
"n_atom_properties": 158,
51+
"n_bond_properties": 7,
52+
"n_molecule_properties": 200,
53+
},
54+
)
4055
model.eval()
4156
return model
4257

@@ -55,14 +70,21 @@ def read_smiles(self, smiles):
5570
# use default value if we meet an unseen value
5671
if isinstance(prop.encoder, IndexEncoder):
5772
if str(value) in prop.encoder.cache:
58-
index = prop.encoder.cache.index(str(value)) + prop.encoder.offset
73+
index = (
74+
prop.encoder.cache.index(str(value)) + prop.encoder.offset
75+
)
5976
else:
6077
index = 0
61-
print(f"Unknown property value {value} for property {prop} at smiles {smiles}")
78+
print(
79+
f"Unknown property value {value} for property {prop} at smiles {smiles}"
80+
)
6281
if isinstance(prop.encoder, OneHotEncoder):
63-
encoded_values.append(torch.nn.functional.one_hot(
64-
torch.tensor(index), num_classes=prop.encoder.get_encoding_length()
65-
))
82+
encoded_values.append(
83+
torch.nn.functional.one_hot(
84+
torch.tensor(index),
85+
num_classes=prop.encoder.get_encoding_length(),
86+
)
87+
)
6688
else:
6789
encoded_values.append(torch.tensor([index]))
6890

@@ -77,9 +99,7 @@ def read_smiles(self, smiles):
7799
if len(encoded_values.size()) == 1:
78100
encoded_values = encoded_values.unsqueeze(1)
79101
else:
80-
encoded_values = torch.zeros(
81-
(0, prop.encoder.get_encoding_length())
82-
)
102+
encoded_values = torch.zeros((0, prop.encoder.get_encoding_length()))
83103
if isinstance(prop, p.AtomProperty):
84104
x = torch.cat([x, encoded_values], dim=1)
85105
elif isinstance(prop, p.BondProperty):
@@ -93,4 +113,4 @@ def read_smiles(self, smiles):
93113
edge_attr=edge_attr,
94114
molecule_attr=molecule_attr,
95115
)
96-
return d
116+
return d

0 commit comments

Comments
 (0)