Skip to content

Commit c30e06d

Browse files
committed
reformat with black, add development dependencies
1 parent 7ebbacb commit c30e06d

File tree

10 files changed

+213
-81
lines changed

10 files changed

+213
-81
lines changed

chebifier/cli.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,19 +72,26 @@ def predict(
7272
use_confidence,
7373
resolve_inconsistencies=True,
7474
):
75-
"""Predict ChEBI classes for SMILES strings using an ensemble model.
76-
"""
75+
"""Predict ChEBI classes for SMILES strings using an ensemble model."""
7776
# Load configuration from YAML file
7877
if not ensemble_config:
7978
print(f"Using default ensemble configuration")
80-
with importlib.resources.files("chebifier").joinpath("ensemble.yml").open("r") as f:
79+
with (
80+
importlib.resources.files("chebifier")
81+
.joinpath("ensemble.yml")
82+
.open("r") as f
83+
):
8184
config = yaml.safe_load(f)
8285
else:
8386
print(f"Loading ensemble configuration from {ensemble_config}")
8487
with open(ensemble_config, "r") as f:
8588
config = yaml.safe_load(f)
8689

87-
with importlib.resources.files("chebifier").joinpath("model_registry.yml").open("r") as f:
90+
with (
91+
importlib.resources.files("chebifier")
92+
.joinpath("model_registry.yml")
93+
.open("r") as f
94+
):
8895
model_registry = yaml.safe_load(f)
8996

9097
new_config = {}
@@ -101,7 +108,11 @@ def predict(
101108
config = new_config
102109

103110
# Instantiate ensemble model
104-
ensemble = ENSEMBLES[ensemble_type](config, chebi_version=chebi_version, resolve_inconsistencies=resolve_inconsistencies)
111+
ensemble = ENSEMBLES[ensemble_type](
112+
config,
113+
chebi_version=chebi_version,
114+
resolve_inconsistencies=resolve_inconsistencies,
115+
)
105116

106117
# Collect SMILES strings from arguments and/or file
107118
smiles_list = list(smiles)

chebifier/ensemble/base_ensemble.py

Lines changed: 49 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,12 @@
1212

1313
class BaseEnsemble:
1414

15-
def __init__(self, model_configs: dict, chebi_version: int = 241, resolve_inconsistencies: bool = True):
15+
def __init__(
16+
self,
17+
model_configs: dict,
18+
chebi_version: int = 241,
19+
resolve_inconsistencies: bool = True,
20+
):
1621
# Deferred Import: To avoid circular import error
1722
from chebifier.model_registry import MODEL_TYPES
1823

@@ -28,33 +33,43 @@ def __init__(self, model_configs: dict, chebi_version: int = 241, resolve_incons
2833
if os.path.isfile(file):
2934
self.disjoint_files.append(file)
3035
else:
31-
print(f"Disjoint axiom file {file} not found. Loading from huggingface instead...")
36+
print(
37+
f"Disjoint axiom file {file} not found. Loading from huggingface instead..."
38+
)
3239
from chebifier.hugging_face import download_model_files
33-
self.disjoint_files.append(download_model_files({
34-
"repo_id": "chebai/chebifier",
35-
"repo_type": "dataset",
36-
"files": {"disjoint_file": os.path.basename(file)},
37-
})["disjoint_file"])
40+
41+
self.disjoint_files.append(
42+
download_model_files(
43+
{
44+
"repo_id": "chebai/chebifier",
45+
"repo_type": "dataset",
46+
"files": {"disjoint_file": os.path.basename(file)},
47+
}
48+
)["disjoint_file"]
49+
)
3850

3951
self.models = []
4052
self.positive_prediction_threshold = 0.5
4153
for model_name, model_config in model_configs.items():
4254
model_cls = MODEL_TYPES[model_config["type"]]
4355
if "hugging_face" in model_config:
4456
from chebifier.hugging_face import download_model_files
57+
4558
hugging_face_kwargs = download_model_files(model_config["hugging_face"])
4659
else:
4760
hugging_face_kwargs = {}
4861
if "package_name" in model_config:
4962
check_package_installed(model_config["package_name"])
5063

5164
model_instance = model_cls(
52-
model_name, **model_config, **hugging_face_kwargs, chebi_graph=self.chebi_graph
65+
model_name,
66+
**model_config,
67+
**hugging_face_kwargs,
68+
chebi_graph=self.chebi_graph,
5369
)
5470
assert isinstance(model_instance, BasePredictor)
5571
self.models.append(model_instance)
5672

57-
5873
if resolve_inconsistencies:
5974
self.smoother = PredictionSmoother(
6075
self.chebi_dataset,
@@ -96,7 +111,9 @@ def gather_predictions(self, smiles_list):
96111

97112
return ordered_logits, predicted_classes
98113

99-
def consolidate_predictions(self, predictions, classwise_weights, predicted_classes, **kwargs):
114+
def consolidate_predictions(
115+
self, predictions, classwise_weights, predicted_classes, **kwargs
116+
):
100117
"""
101118
Aggregates predictions from multiple models using weighted majority voting.
102119
Optimized version using tensor operations instead of for loops.
@@ -152,9 +169,13 @@ def consolidate_predictions(self, predictions, classwise_weights, predicted_clas
152169
if self.smoother is not None:
153170
self.smoother.set_label_names(class_names)
154171
smooth_net_score = self.smoother(net_score)
155-
class_decisions = (smooth_net_score > 0.5) & has_valid_predictions # Shape: (num_smiles, num_classes)
172+
class_decisions = (
173+
smooth_net_score > 0.5
174+
) & has_valid_predictions # Shape: (num_smiles, num_classes)
156175
else:
157-
class_decisions = (net_score > 0) & has_valid_predictions # Shape: (num_smiles, num_classes)
176+
class_decisions = (
177+
net_score > 0
178+
) & has_valid_predictions # Shape: (num_smiles, num_classes)
158179
end_time = time.perf_counter()
159180
print(f"Prediction smoothing took {end_time - start_time:.2f} seconds")
160181

@@ -178,7 +199,9 @@ def predict_smiles_list(
178199
smiles_list
179200
)
180201
if len(predicted_classes) == 0:
181-
print(f"Warning: No classes have been predicted for the given SMILES list.")
202+
print(
203+
f"Warning: No classes have been predicted for the given SMILES list."
204+
)
182205
# save predictions
183206
torch.save(ordered_predictions, preds_file)
184207
with open(predicted_classes_file, "w") as f:
@@ -203,7 +226,14 @@ def predict_smiles_list(
203226
class_names = list(predicted_classes.keys())
204227
class_indices = {predicted_classes[cls]: cls for cls in class_names}
205228
result = [
206-
[class_indices[idx.item()] for idx in torch.nonzero(i, as_tuple=True)[0]] if not failure else None
229+
(
230+
[
231+
class_indices[idx.item()]
232+
for idx in torch.nonzero(i, as_tuple=True)[0]
233+
]
234+
if not failure
235+
else None
236+
)
207237
for i, failure in zip(class_decisions, is_failure)
208238
]
209239

@@ -240,7 +270,11 @@ def predict_smiles_list(
240270
}
241271
)
242272
r = ensemble.predict_smiles_list(
243-
["[NH3+]CCCC[C@H](NC(=O)[C@@H]([NH3+])CC([O-])=O)C([O-])=O", "C[C@H](N)C(=O)NCC(O)=O#", ""],
273+
[
274+
"[NH3+]CCCC[C@H](NC(=O)[C@@H]([NH3+])CC([O-])=O)C([O-])=O",
275+
"C[C@H](N)C(=O)NCC(O)=O#",
276+
"",
277+
],
244278
load_preds_if_possible=False,
245279
)
246280
print(len(r), r[0])

chebifier/model_registry.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,13 @@
77
ChemlogPeptidesPredictor,
88
ElectraPredictor,
99
ResGatedPredictor,
10-
ChEBILookupPredictor
10+
ChEBILookupPredictor,
1111
)
1212
from chebifier.prediction_models.c3p_predictor import C3PPredictor
13-
from chebifier.prediction_models.chemlog_predictor import ChemlogXMolecularEntityPredictor, ChemlogOrganoXCompoundPredictor
13+
from chebifier.prediction_models.chemlog_predictor import (
14+
ChemlogXMolecularEntityPredictor,
15+
ChemlogOrganoXCompoundPredictor,
16+
)
1417

1518
ENSEMBLES = {
1619
"mv": BaseEnsemble,
@@ -26,7 +29,7 @@
2629
"chebi_lookup": ChEBILookupPredictor,
2730
"chemlog_element": ChemlogXMolecularEntityPredictor,
2831
"chemlog_organox": ChemlogOrganoXCompoundPredictor,
29-
"c3p": C3PPredictor
32+
"c3p": C3PPredictor,
3033
}
3134

3235

chebifier/prediction_models/__init__.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,12 @@
33
from .electra_predictor import ElectraPredictor
44
from .gnn_predictor import ResGatedPredictor
55
from .chebi_lookup import ChEBILookupPredictor
6-
__all__ = ["BasePredictor", "ChemlogPeptidesPredictor", "ElectraPredictor", "ResGatedPredictor", "ChEBILookupPredictor",
7-
"ChemlogExtraPredictor"]
6+
7+
__all__ = [
8+
"BasePredictor",
9+
"ChemlogPeptidesPredictor",
10+
"ElectraPredictor",
11+
"ResGatedPredictor",
12+
"ChEBILookupPredictor",
13+
"ChemlogExtraPredictor",
14+
]

chebifier/prediction_models/c3p_predictor.py

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,37 @@ class C3PPredictor(BasePredictor):
1212
Wrapper for C3P (url).
1313
"""
1414

15-
def __init__(self, model_name: str, program_directory: Optional[Path]=None, chemical_classes: Optional[List[str]]=None, **kwargs):
15+
def __init__(
16+
self,
17+
model_name: str,
18+
program_directory: Optional[Path] = None,
19+
chemical_classes: Optional[List[str]] = None,
20+
**kwargs,
21+
):
1622
super().__init__(model_name, **kwargs)
1723
self.program_directory = program_directory
1824
self.chemical_classes = chemical_classes
1925
self.chebi_graph = kwargs.get("chebi_graph", None)
2026

2127
@lru_cache(maxsize=100)
2228
def predict_smiles_tuple(self, smiles_list: tuple[str]) -> list:
23-
result_list = c3p_classifier.classify(list(smiles_list), self.program_directory, self.chemical_classes, strict=False)
29+
result_list = c3p_classifier.classify(
30+
list(smiles_list),
31+
self.program_directory,
32+
self.chemical_classes,
33+
strict=False,
34+
)
2435
result_reformatted = [dict() for _ in range(len(smiles_list))]
2536
for result in result_list:
2637
chebi_id = result.class_id.split(":")[1]
27-
result_reformatted[smiles_list.index(result.input_smiles)][chebi_id] = result.is_match
38+
result_reformatted[smiles_list.index(result.input_smiles)][
39+
chebi_id
40+
] = result.is_match
2841
if result.is_match and self.chebi_graph is not None:
2942
for parent in list(self.chebi_graph.predecessors(int(chebi_id))):
30-
result_reformatted[smiles_list.index(result.input_smiles)][str(parent)] = 1
43+
result_reformatted[smiles_list.index(result.input_smiles)][
44+
str(parent)
45+
] = 1
3146
return result_reformatted
3247

3348
def explain_smiles(self, smiles):
@@ -36,14 +51,22 @@ def explain_smiles(self, smiles):
3651
than 300 classes, only take the positive ones.
3752
"""
3853
highlights = []
39-
result_list = c3p_classifier.classify([smiles], self.program_directory, self.chemical_classes, strict=False)
54+
result_list = c3p_classifier.classify(
55+
[smiles], self.program_directory, self.chemical_classes, strict=False
56+
)
4057
for result in result_list:
4158
if result.is_match:
4259
highlights.append(
43-
("text", f"For class {result.class_name} ({result.class_id}), C3P gave the following explanation: {result.reason}")
60+
(
61+
"text",
62+
f"For class {result.class_name} ({result.class_id}), C3P gave the following explanation: {result.reason}",
63+
)
4464
)
4565
highlights = [
46-
("text", f"C3P made positive predictions for {len(highlights)} classes. The explanations are as follows:")
47-
] + highlights
66+
(
67+
"text",
68+
f"C3P made positive predictions for {len(highlights)} classes. The explanations are as follows:",
69+
)
70+
] + highlights
4871

49-
return {"highlights": highlights}
72+
return {"highlights": highlights}

0 commit comments

Comments
 (0)