Skip to content

Commit e0b3ca7

Browse files
committed
merge from dev
2 parents d2c586a + fd0b633 commit e0b3ca7

File tree

8 files changed

+453
-12
lines changed

8 files changed

+453
-12
lines changed

chebifier/ensemble/base_ensemble.py

Lines changed: 44 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@ def __init__(self, model_configs: dict):
1616
self.positive_prediction_threshold = 0.5
1717
for model_name, model_config in model_configs.items():
1818
model_cls = MODEL_TYPES[model_config["type"]]
19-
model_instance = model_cls(**model_config)
19+
model_instance = model_cls(model_name, **model_config)
2020
assert isinstance(model_instance, BasePredictor)
2121
self.models.append(model_instance)
2222

@@ -73,8 +73,12 @@ def consolidate_predictions(
7373
has_valid_predictions = valid_counts > 0
7474

7575
# Calculate positive and negative predictions for all classes at once
76-
positive_mask = (predictions > 0.5) & valid_predictions
77-
negative_mask = (predictions < 0.5) & valid_predictions
76+
positive_mask = (
77+
predictions > self.positive_prediction_threshold
78+
) & valid_predictions
79+
negative_mask = (
80+
predictions < self.positive_prediction_threshold
81+
) & valid_predictions
7882

7983
confidence = 2 * torch.abs(
8084
predictions.nan_to_num() - self.positive_prediction_threshold
@@ -134,6 +138,7 @@ def predict_smiles_list(self, smiles_list, load_preds_if_possible=True) -> list:
134138
with open(predicted_classes_file, "w") as f:
135139
for cls in predicted_classes:
136140
f.write(f"{cls}\n")
141+
predicted_classes = {cls: i for i, cls in enumerate(predicted_classes)}
137142
else:
138143
print(
139144
f"Loading predictions from {preds_file} and label indexes from {predicted_classes_file}"
@@ -149,3 +154,39 @@ def predict_smiles_list(self, smiles_list, load_preds_if_possible=True) -> list:
149154
ordered_predictions, predicted_classes, classwise_weights
150155
)
151156
return aggregated_predictions
157+
158+
159+
if __name__ == "__main__":
160+
ensemble = BaseEnsemble(
161+
{
162+
"resgated_0ps1g189": {
163+
"type": "resgated",
164+
"ckpt_path": "../python-chebai/logs/downloaded_ckpts/electra_resgated_comp/resgated_80-10-10_0ps1g189_epoch=122.ckpt",
165+
"target_labels_path": "../python-chebai/data/chebi_v241/ChEBI50/processed/classes.txt",
166+
"molecular_properties": [
167+
"chebai_graph.preprocessing.properties.AtomType",
168+
"chebai_graph.preprocessing.properties.NumAtomBonds",
169+
"chebai_graph.preprocessing.properties.AtomCharge",
170+
"chebai_graph.preprocessing.properties.AtomAromaticity",
171+
"chebai_graph.preprocessing.properties.AtomHybridization",
172+
"chebai_graph.preprocessing.properties.AtomNumHs",
173+
"chebai_graph.preprocessing.properties.BondType",
174+
"chebai_graph.preprocessing.properties.BondInRing",
175+
"chebai_graph.preprocessing.properties.BondAromaticity",
176+
"chebai_graph.preprocessing.properties.RDKit2DNormalized",
177+
],
178+
"classwise_weights_path": "../python-chebai/metrics_0ps1g189_80-10-10.json",
179+
},
180+
"electra_14ko0zcf": {
181+
"type": "electra",
182+
"ckpt_path": "../python-chebai/logs/downloaded_ckpts/electra_resgated_comp/electra_80-10-10_14ko0zcf_epoch=193.ckpt",
183+
"target_labels_path": "../python-chebai/data/chebi_v241/ChEBI50/processed/classes.txt",
184+
"classwise_weights_path": "../python-chebai/metrics_electra_14ko0zcf_80-10-10.json",
185+
},
186+
}
187+
)
188+
r = ensemble.predict_smiles_list(
189+
["[NH3+]CCCC[C@H](NC(=O)[C@@H]([NH3+])CC([O-])=O)C([O-])=O"],
190+
load_preds_if_possible=False,
191+
)
192+
print(len(r), r[0])

chebifier/ensemble/weighted_majority_ensemble.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,7 @@ def calculate_classwise_weights(self, predicted_classes):
3737
"""
3838
Given the positions of predicted classes in the predictions tensor, assign weights to each class. The
3939
result is two tensors of shape (num_predicted_classes, num_models). The weight for each class is the model_weight
40-
(default: 1) multiplied by the class-specific validation-f1 (default 1).
40+
(default: 1) multiplied by (1 + the class-specific validation-f1 (default 1)).
4141
"""
4242
weights_by_cls = torch.ones(len(predicted_classes), len(self.models))
4343
for j, model in enumerate(self.models):
@@ -51,7 +51,7 @@ def calculate_classwise_weights(self, predicted_classes):
5151
* weights["TP"]
5252
/ (2 * weights["TP"] + weights["FP"] + weights["FN"])
5353
)
54-
weights_by_cls[predicted_classes[cls], j] *= f1
54+
weights_by_cls[predicted_classes[cls], j] *= 1 + f1
5555

5656
print("Calculated model weightings. The average weights are:")
5757
for i, model in enumerate(self.models):

chebifier/prediction_models/base_predictor.py

Lines changed: 11 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -19,5 +19,16 @@ def __init__(
1919
else:
2020
self.classwise_weights = None
2121

22+
self._description = kwargs.get("description", None)
23+
2224
def predict_smiles_list(self, smiles_list: list[str]) -> dict:
2325
raise NotImplementedError
26+
27+
@property
28+
def info_text(self):
29+
if self._description is None:
30+
return "No description is available for this model."
31+
return self._description
32+
33+
def explain_smiles(self, smiles):
34+
return None

0 commit comments

Comments
 (0)