
Commit 90aedd4

reformat with black
1 parent f8583cb commit 90aedd4

File tree

3 files changed: +148 -67 lines changed
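The diffs below are mechanical: black normalizes string quotes to double quotes, wraps calls that would exceed its default 88-character line limit, adds trailing commas to the wrapped forms, and normalizes blank lines. As a rough illustration (not part of this commit), the snippet below applies black programmatically to a made-up input via its format_str API; it assumes the black package is installed.

# Minimal sketch, not from this repository: show the kind of rewrite black applies.
# Assumes `black` is installed; the input string is an invented example.
import black

src = "if __name__ == '__main__':\n    cli()\n"

# black.Mode() defaults to an 88-character line length, the same limit that
# forces the long @click.option(...) calls in cli.py to be wrapped below.
formatted = black.format_str(src, mode=black.Mode())
print(formatted)  # single quotes become double quotes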

chebifier/__main__.py

Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
 from chebifier.cli import cli
 
-if __name__ == '__main__':
-    cli()
+if __name__ == "__main__":
+    cli()

chebifier/cli.py

Lines changed: 68 additions & 18 deletions
@@ -5,49 +5,99 @@
 
 from .model_registry import ENSEMBLES
 
+
 @click.group()
 def cli():
     """Command line interface for Chebifier."""
     pass
 
+
 @cli.command()
-@click.option('--config_file', type=click.Path(exists=True), default=os.path.join('configs', 'huggingface_config.yml'), help="Configuration file for ensemble models")
-@click.option('--smiles', '-s', multiple=True, help='SMILES strings to predict')
-@click.option('--smiles-file', '-f', type=click.Path(exists=True), help='File containing SMILES strings (one per line)')
-@click.option('--output', '-o', type=click.Path(), help='Output file to save predictions (optional)')
-@click.option('--ensemble-type', '-e', type=click.Choice(ENSEMBLES.keys()), default='mv', help='Type of ensemble to use (default: Majority Voting)')
-@click.option("--chebi-version", "-v", type=int, default=241, help="ChEBI version to use for checking consistency (default: 241)")
-@click.option("--use-confidence", "-c", is_flag=True, default=True, help="Weight predictions based on how 'confident' a model is in its prediction (default: True)")
-def predict(config_file, smiles, smiles_file, output, ensemble_type, chebi_version, use_confidence):
+@click.option(
+    "--config_file",
+    type=click.Path(exists=True),
+    default=os.path.join("configs", "huggingface_config.yml"),
+    help="Configuration file for ensemble models",
+)
+@click.option("--smiles", "-s", multiple=True, help="SMILES strings to predict")
+@click.option(
+    "--smiles-file",
+    "-f",
+    type=click.Path(exists=True),
+    help="File containing SMILES strings (one per line)",
+)
+@click.option(
+    "--output",
+    "-o",
+    type=click.Path(),
+    help="Output file to save predictions (optional)",
+)
+@click.option(
+    "--ensemble-type",
+    "-e",
+    type=click.Choice(ENSEMBLES.keys()),
+    default="mv",
+    help="Type of ensemble to use (default: Majority Voting)",
+)
+@click.option(
+    "--chebi-version",
+    "-v",
+    type=int,
+    default=241,
+    help="ChEBI version to use for checking consistency (default: 241)",
+)
+@click.option(
+    "--use-confidence",
+    "-c",
+    is_flag=True,
+    default=True,
+    help="Weight predictions based on how 'confident' a model is in its prediction (default: True)",
+)
+def predict(
+    config_file,
+    smiles,
+    smiles_file,
+    output,
+    ensemble_type,
+    chebi_version,
+    use_confidence,
+):
     """Predict ChEBI classes for SMILES strings using an ensemble model.
-
+
     CONFIG_FILE is the path to a YAML configuration file for the ensemble model.
     """
     # Load configuration from YAML file
-    with open(config_file, 'r') as f:
+    with open(config_file, "r") as f:
         config = yaml.safe_load(f)
-
+
     # Instantiate ensemble model
     ensemble = ENSEMBLES[ensemble_type](config, chebi_version=chebi_version)
-
+
     # Collect SMILES strings from arguments and/or file
     smiles_list = list(smiles)
     if smiles_file:
-        with open(smiles_file, 'r') as f:
+        with open(smiles_file, "r") as f:
             smiles_list.extend([line.strip() for line in f if line.strip()])
-
+
     if not smiles_list:
         click.echo("No SMILES strings provided. Use --smiles or --smiles-file options.")
         return
 
     # Make predictions
-    predictions = ensemble.predict_smiles_list(smiles_list, use_confidence=use_confidence)
+    predictions = ensemble.predict_smiles_list(
+        smiles_list, use_confidence=use_confidence
+    )
 
     if output:
         # save as json
         import json
-        with open(output, 'w') as f:
-            json.dump({smiles: pred for smiles, pred in zip(smiles_list, predictions)}, f, indent=2)
+
+        with open(output, "w") as f:
+            json.dump(
+                {smiles: pred for smiles, pred in zip(smiles_list, predictions)},
+                f,
+                indent=2,
+            )
 
     else:
         # Print results
@@ -59,5 +109,5 @@ def predict(config_file, smiles, smiles_file, output, ensemble_type, chebi_versi
             click.echo(" No predictions")
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     cli()
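For context on the options reformatted above, the predict command can be driven from Python through click's testing utilities. The sketch below is illustrative only: it assumes the chebifier package and the models referenced by the default configs/huggingface_config.yml are installed and available, and the SMILES string is just an example.

# Minimal sketch, not part of this commit: invoke the reformatted `predict`
# command via click's CliRunner. Assumes chebifier and its model dependencies
# are installed and the default config file exists.
from click.testing import CliRunner

from chebifier.cli import cli

runner = CliRunner()
result = runner.invoke(
    cli,
    [
        "predict",
        "--smiles",
        "CC(=O)Oc1ccccc1C(=O)O",  # example SMILES (aspirin)
        "--ensemble-type",
        "mv",
    ],
)
print(result.exit_code, result.output)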

chebifier/ensemble/base_ensemble.py

Lines changed: 78 additions & 47 deletions
@@ -22,18 +22,19 @@ def __init__(self, model_configs: dict, chebi_version: int = 241):
                 hugging_face_kwargs = download_model_files(model_config["hugging_face"])
             else:
                 hugging_face_kwargs = {}
-            model_instance = model_cls(model_name, **model_config, **hugging_face_kwargs)
+            model_instance = model_cls(
+                model_name, **model_config, **hugging_face_kwargs
+            )
             assert isinstance(model_instance, BasePredictor)
             self.models.append(model_instance)
 
         self.chebi_dataset = ChEBIOver50(chebi_version=chebi_version)
         self.chebi_dataset._download_required_data()  # download chebi if not already downloaded
-        self.disjoint_files=[
+        self.disjoint_files = [
             os.path.join("data", "disjoint_chebi.csv"),
-            os.path.join("data", "disjoint_additional.csv")
+            os.path.join("data", "disjoint_additional.csv"),
         ]
 
-
     def gather_predictions(self, smiles_list):
         # get predictions from all models for the SMILES list
         # order them by alphabetically by label class
@@ -60,11 +61,12 @@ def gather_predictions(self, smiles_list):
             ):
                 if logits_for_smiles is not None:
                     for cls in logits_for_smiles:
-                        ordered_logits[j, predicted_classes_dict[cls], i] = logits_for_smiles[cls]
+                        ordered_logits[j, predicted_classes_dict[cls], i] = (
+                            logits_for_smiles[cls]
+                        )
 
         return ordered_logits, predicted_classes
 
-
     def consolidate_predictions(self, predictions, classwise_weights, **kwargs):
         """
         Aggregates predictions from multiple models using weighted majority voting.
@@ -80,11 +82,17 @@ def consolidate_predictions(self, predictions, classwise_weights, **kwargs):
         has_valid_predictions = valid_counts > 0
 
         # Calculate positive and negative predictions for all classes at once
-        positive_mask = (predictions > self.positive_prediction_threshold) & valid_predictions
-        negative_mask = (predictions < self.positive_prediction_threshold) & valid_predictions
+        positive_mask = (
+            predictions > self.positive_prediction_threshold
+        ) & valid_predictions
+        negative_mask = (
+            predictions < self.positive_prediction_threshold
+        ) & valid_predictions
 
         if "use_confidence" in kwargs and kwargs["use_confidence"]:
-            confidence = 2 * torch.abs(predictions.nan_to_num() - self.positive_prediction_threshold)
+            confidence = 2 * torch.abs(
+                predictions.nan_to_num() - self.positive_prediction_threshold
+            )
         else:
             confidence = torch.ones_like(predictions)
 
@@ -95,18 +103,22 @@ def consolidate_predictions(self, predictions, classwise_weights, **kwargs):
         # Calculate weighted predictions using broadcasting
         # predictions shape: (num_smiles, num_classes, num_models)
         # weights shape: (num_classes, num_models)
-        positive_weighted = positive_mask.float() * confidence * pos_weights.unsqueeze(0)
-        negative_weighted = negative_mask.float() * confidence * neg_weights.unsqueeze(0)
+        positive_weighted = (
+            positive_mask.float() * confidence * pos_weights.unsqueeze(0)
+        )
+        negative_weighted = (
+            negative_mask.float() * confidence * neg_weights.unsqueeze(0)
+        )
 
         # Sum over models dimension
         positive_sum = positive_weighted.sum(dim=2)  # Shape: (num_smiles, num_classes)
         negative_sum = negative_weighted.sum(dim=2)  # Shape: (num_smiles, num_classes)
 
         # Determine which classes to include for each SMILES
         net_score = positive_sum - negative_sum  # Shape: (num_smiles, num_classes)
-        class_decisions = (net_score > 0) & has_valid_predictions  # Shape: (num_smiles, num_classes)
-
-
+        class_decisions = (
+            net_score > 0
+        ) & has_valid_predictions  # Shape: (num_smiles, num_classes)
 
         return class_decisions
 
@@ -117,29 +129,43 @@ def calculate_classwise_weights(self, predicted_classes):
 
         return positive_weights, negative_weights
 
-    def predict_smiles_list(self, smiles_list, load_preds_if_possible=True, **kwargs) -> list:
+    def predict_smiles_list(
+        self, smiles_list, load_preds_if_possible=True, **kwargs
+    ) -> list:
         preds_file = f"predictions_by_model_{'_'.join(model.model_name for model in self.models)}.pt"
         predicted_classes_file = f"predicted_classes_{'_'.join(model.model_name for model in self.models)}.txt"
         if not load_preds_if_possible or not os.path.isfile(preds_file):
-            ordered_predictions, predicted_classes = self.gather_predictions(smiles_list)
+            ordered_predictions, predicted_classes = self.gather_predictions(
+                smiles_list
+            )
             # save predictions
             torch.save(ordered_predictions, preds_file)
             with open(predicted_classes_file, "w") as f:
                 for cls in predicted_classes:
                     f.write(f"{cls}\n")
             predicted_classes = {cls: i for i, cls in enumerate(predicted_classes)}
         else:
-            print(f"Loading predictions from {preds_file} and label indexes from {predicted_classes_file}")
+            print(
+                f"Loading predictions from {preds_file} and label indexes from {predicted_classes_file}"
+            )
             ordered_predictions = torch.load(preds_file)
             with open(predicted_classes_file, "r") as f:
-                predicted_classes = {line.strip(): i for i, line in enumerate(f.readlines())}
+                predicted_classes = {
+                    line.strip(): i for i, line in enumerate(f.readlines())
+                }
 
         classwise_weights = self.calculate_classwise_weights(predicted_classes)
-        class_decisions = self.consolidate_predictions(ordered_predictions, classwise_weights, **kwargs)
+        class_decisions = self.consolidate_predictions(
+            ordered_predictions, classwise_weights, **kwargs
+        )
         # Smooth predictions
         class_names = list(predicted_classes.keys())
         # initialise new smoother class since we don't know the labels beforehand (this could be more efficient)
-        new_smoother = PredictionSmoother(self.chebi_dataset, label_names=class_names, disjoint_files=self.disjoint_files)
+        new_smoother = PredictionSmoother(
+            self.chebi_dataset,
+            label_names=class_names,
+            disjoint_files=self.disjoint_files,
+        )
         class_decisions = new_smoother(class_decisions)
 
         class_names = list(predicted_classes.keys())
@@ -153,31 +179,36 @@ def predict_smiles_list(self, smiles_list, load_preds_if_possible=True, **kwargs
 
 
 if __name__ == "__main__":
-    ensemble = BaseEnsemble({"resgated_0ps1g189":{
-        "type": "resgated",
-        "ckpt_path": "data/0ps1g189/epoch=122.ckpt",
-        "target_labels_path": "data/chebi_v241/ChEBI50/processed/classes.txt",
-        "molecular_properties": [
-            "chebai_graph.preprocessing.properties.AtomType",
-            "chebai_graph.preprocessing.properties.NumAtomBonds",
-            "chebai_graph.preprocessing.properties.AtomCharge",
-            "chebai_graph.preprocessing.properties.AtomAromaticity",
-            "chebai_graph.preprocessing.properties.AtomHybridization",
-            "chebai_graph.preprocessing.properties.AtomNumHs",
-            "chebai_graph.preprocessing.properties.BondType",
-            "chebai_graph.preprocessing.properties.BondInRing",
-            "chebai_graph.preprocessing.properties.BondAromaticity",
-            "chebai_graph.preprocessing.properties.RDKit2DNormalized",
-        ],
-        #"classwise_weights_path" : "../python-chebai/metrics_0ps1g189_80-10-10.json"
-    },
-
-        "electra_14ko0zcf": {
-            "type": "electra",
-            "ckpt_path": "data/14ko0zcf/epoch=193.ckpt",
-            "target_labels_path": "data/chebi_v241/ChEBI50/processed/classes.txt",
-            #"classwise_weights_path": "../python-chebai/metrics_electra_14ko0zcf_80-10-10.json",
-        }
-    })
-    r = ensemble.predict_smiles_list(["[NH3+]CCCC[C@H](NC(=O)[C@@H]([NH3+])CC([O-])=O)C([O-])=O"], load_preds_if_possible=False)
+    ensemble = BaseEnsemble(
+        {
+            "resgated_0ps1g189": {
+                "type": "resgated",
+                "ckpt_path": "data/0ps1g189/epoch=122.ckpt",
+                "target_labels_path": "data/chebi_v241/ChEBI50/processed/classes.txt",
+                "molecular_properties": [
+                    "chebai_graph.preprocessing.properties.AtomType",
+                    "chebai_graph.preprocessing.properties.NumAtomBonds",
+                    "chebai_graph.preprocessing.properties.AtomCharge",
+                    "chebai_graph.preprocessing.properties.AtomAromaticity",
+                    "chebai_graph.preprocessing.properties.AtomHybridization",
+                    "chebai_graph.preprocessing.properties.AtomNumHs",
+                    "chebai_graph.preprocessing.properties.BondType",
+                    "chebai_graph.preprocessing.properties.BondInRing",
+                    "chebai_graph.preprocessing.properties.BondAromaticity",
+                    "chebai_graph.preprocessing.properties.RDKit2DNormalized",
+                ],
+                # "classwise_weights_path" : "../python-chebai/metrics_0ps1g189_80-10-10.json"
+            },
+            "electra_14ko0zcf": {
+                "type": "electra",
+                "ckpt_path": "data/14ko0zcf/epoch=193.ckpt",
+                "target_labels_path": "data/chebi_v241/ChEBI50/processed/classes.txt",
+                # "classwise_weights_path": "../python-chebai/metrics_electra_14ko0zcf_80-10-10.json",
+            },
+        }
+    )
+    r = ensemble.predict_smiles_list(
+        ["[NH3+]CCCC[C@H](NC(=O)[C@@H]([NH3+])CC([O-])=O)C([O-])=O"],
+        load_preds_if_possible=False,
+    )
     print(len(r), r[0])
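Functionally, nothing above changes: consolidate_predictions still lets each model vote for or against a class depending on whether its prediction lies above or below positive_prediction_threshold, optionally scales the votes by confidence and per-class weights, and keeps a class when the net score is positive. The standalone sketch below replays that voting on toy tensors so the broadcasting shapes in the comments are easy to follow; the 0.5 threshold, uniform weights, and all numbers are assumptions for illustration, not values taken from the repository.

# Standalone sketch of the weighted-voting logic reformatted above, on toy data.
# Shapes follow the comments in base_ensemble.py: predictions are
# (num_smiles, num_classes, num_models), weights are (num_classes, num_models).
# The 0.5 threshold and every number here are illustrative assumptions.
import torch

threshold = 0.5
predictions = torch.tensor(
    # 1 SMILES, 2 classes, 3 models; NaN marks a class a model did not predict
    [[[0.9, 0.2, 0.7], [0.1, float("nan"), 0.6]]]
)
pos_weights = torch.ones(2, 3)  # (num_classes, num_models)
neg_weights = torch.ones(2, 3)

valid = ~torch.isnan(predictions)
positive_mask = (predictions > threshold) & valid
negative_mask = (predictions < threshold) & valid

# Confidence: distance from the threshold, rescaled to [0, 1]
confidence = 2 * torch.abs(predictions.nan_to_num() - threshold)

positive_sum = (positive_mask.float() * confidence * pos_weights.unsqueeze(0)).sum(dim=2)
negative_sum = (negative_mask.float() * confidence * neg_weights.unsqueeze(0)).sum(dim=2)

net_score = positive_sum - negative_sum
class_decisions = (net_score > 0) & (valid.sum(dim=2) > 0)
print(class_decisions)  # tensor([[ True, False]])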
