
Commit ebc450f

Merge branch 'refs/heads/dev' into feature/api_downloadble_models
# Conflicts:
#   chebifier/cli.py
#   chebifier/ensemble/base_ensemble.py

2 parents e0b3ca7 + 2c724e7


8 files changed: +723 / -155 lines changed

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
+# This workflow will upload a Python Package to PyPI when a release is created
+# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
+
+# This workflow uses actions that are not certified by GitHub.
+# They are provided by a third-party and are governed by
+# separate terms of service, privacy policy, and support
+# documentation.
+
+name: Upload Python Package
+
+on:
+  release:
+    types: [published]
+
+permissions:
+  contents: read
+
+jobs:
+  release-build:
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - uses: actions/setup-python@v5
+        with:
+          python-version: "3.12"
+
+      - name: Build release distributions
+        run: |
+          python -m pip install build
+          python -m build
+
+      - name: Upload distributions
+        uses: actions/upload-artifact@v4
+        with:
+          name: release-dists
+          path: dist/
+
+  pypi-publish:
+    runs-on: ubuntu-latest
+    needs:
+      - release-build
+    permissions:
+      # IMPORTANT: this permission is mandatory for trusted publishing
+      id-token: write
+
+    # Dedicated environments with protections for publishing are strongly recommended.
+    # For more information, see: https://docs.github.com/en/actions/deployment/targeting-different-environments/using-environments-for-deployment#deployment-protection-rules
+    environment:
+      name: pypi
+      # OPTIONAL: uncomment and update to include your PyPI project URL in the deployment status:
+      # url: https://pypi.org/p/YOURPROJECT
+      #
+      # ALTERNATIVE: if your GitHub Release name is the PyPI project version string
+      # ALTERNATIVE: exactly, uncomment the following line instead:
+      # url: https://pypi.org/project/YOURPROJECT/${{ github.event.release.name }}
+
+    steps:
+      - name: Retrieve release distributions
+        uses: actions/download-artifact@v4
+        with:
+          name: release-dists
+          path: dist/
+
+      - name: Publish release distributions to PyPI
+        uses: pypa/gh-action-pypi-publish@release/v1
+        with:
+          packages-dir: dist/

README.md

Lines changed: 61 additions & 30 deletions
@@ -1,5 +1,5 @@
 # python-chebifier
-An AI ensemble model for predicting chemical classes.
+An AI ensemble model for predicting chemical classes in the ChEBI ontology.
 
 ## Installation
 
@@ -12,6 +12,9 @@ cd python-chebifier
 pip install -e .
 ```
 
+Some dependencies of `chebai-graph` cannot be installed automatically. If you want to use Graph Neural Networks, follow
+the instructions in the [chebai-graph repository](https://github.com/ChEB-AI/python-chebai-graph).
+
 ## Usage
 
 ### Command Line Interface
@@ -23,39 +26,18 @@ The package provides a command-line interface (CLI) for making predictions using
 python -m chebifier.cli --help
 
 # Make predictions using a configuration file
-python -m chebifier.cli predict example_config.yml --smiles "CC(=O)OC1=CC=CC=C1C(=O)O" "C1=CC=C(C=C1)C(=O)O"
+python -m chebifier.cli predict configs/example_config.yml --smiles "CC(=O)OC1=CC=CC=C1C(=O)O" "C1=CC=C(C=C1)C(=O)O"
 
 # Make predictions using SMILES from a file
-python -m chebifier.cli predict example_config.yml --smiles-file smiles.txt
+python -m chebifier.cli predict configs/example_config.yml --smiles-file smiles.txt
 ```
 
 ### Configuration File
 
-The CLI requires a YAML configuration file that defines the ensemble model. Here's an example:
-
-```yaml
-# Example configuration file for Chebifier ensemble model
-
-# Each key in the top-level dictionary is a model name
-model1:
-  # Required: type of model (must be one of the keys in MODEL_TYPES)
-  type: electra
-  # Required: name of the model
-  model_name: electra_model1
-  # Required: path to the checkpoint file
-  ckpt_path: /path/to/checkpoint1.ckpt
-  # Required: path to the target labels file
-  target_labels_path: /path/to/target_labels1.txt
-  # Optional: batch size for predictions (default is likely defined in the model)
-  batch_size: 32
-
-model2:
-  type: electra
-  model_name: electra_model2
-  ckpt_path: /path/to/checkpoint2.ckpt
-  target_labels_path: /path/to/target_labels2.txt
-  batch_size: 64
-```
+The CLI requires a YAML configuration file that defines the ensemble model. An example can be found in `configs/example_config.yml`.
+
+The models and other required files are trained / generated by our [chebai](https://github.com/ChEB-AI/python-chebai) package.
+Examples for models can be found on [kaggle](https://www.kaggle.com/datasets/sfluegel/chebai).
 
 ### Python API
 
@@ -77,10 +59,59 @@ smiles_list = ["CC(=O)OC1=CC=CC=C1C(=O)O", "C1=CC=C(C=C1)C(=O)O"]
 predictions = ensemble.predict_smiles_list(smiles_list)
 
 # Print results
-for smile, prediction in zip(smiles_list, predictions):
-    print(f"SMILES: {smile}")
+for smiles, prediction in zip(smiles_list, predictions):
+    print(f"SMILES: {smiles}")
     if prediction:
         print(f"Predicted classes: {prediction}")
     else:
         print("No predictions")
 ```
+
+### The ensemble
+
+Given a sample (i.e., a SMILES string) and models $m_1, m_2, \ldots, m_n$, the ensemble works as follows:
+1. Get predictions from each model $m_i$ for the sample.
+2. For each class $c$, aggregate predictions $p_c^{m_i}$ from all models that made a prediction for that class.
+The aggregation happens separately for all positive predictions (i.e., $p_c^{m_i} \geq 0.5$) and all negative predictions
+($p_c^{m_i} < 0.5$). If the aggregated value is larger for the positive predictions than for the negative predictions,
+the ensemble makes a positive prediction for class $c$:
+
+$$
+\text{ensemble}(c) = \begin{cases}
+1 & \text{if } \sum_{i: p_c^{m_i} \geq 0.5} [\text{confidence}_c^{m_i} \cdot \text{model_weight}_{m_i} \cdot \text{trust}_c^{m_i}] > \sum_{i: p_c^{m_i} < 0.5} [\text{confidence}_c^{m_i} \cdot \text{model_weight}_{m_i} \cdot \text{trust}_c^{m_i}] \\
+0 & \text{otherwise}
+\end{cases}
+$$
+
+Here, confidence is the model's (self-reported) confidence in its prediction, calculated as
+$$
+\text{confidence}_c^{m_i} = 2|p_c^{m_i} - 0.5|
+$$
+For example, if a model makes a positive prediction with $p_c^{m_i} = 0.55$, the confidence is $2|0.55 - 0.5| = 0.1$.
+One could say that the model is not very confident in its prediction and very close to switching to a negative prediction.
+If another model is very sure about its negative prediction with $p_c^{m_j} = 0.1$, the confidence is $2|0.1 - 0.5| = 0.8$.
+Therefore, if in doubt, we are more confident in the negative prediction.
+
+Confidence can be disabled by the `use_confidence` parameter of the predict method (default: True).
+
+The model_weight can be set for each model in the configuration file (default: 1). This is used to favor a certain
+model independently of a given class.
+Trust is based on the model's performance on a validation set. After training, we evaluate the Machine Learning models
+on a validation set for each class. If the `ensemble_type` is set to `wmv-f1`, the trust is calculated as 1 + the F1 score.
+If the `ensemble_type` is set to `mv` (the default), the trust is set to 1 for all models.
+
+3. After a decision has been made for each class independently, the consistency of the predictions with regard to the ChEBI hierarchy
+and disjointness axioms is checked. This is
+done in 3 steps:
+- (1) First, the hierarchy is corrected. For each pair of classes $A$ and $B$ where $A$ is a subclass of $B$ (following
+the is-a relation in ChEBI), we set the ensemble prediction of $B$ to 1 if the prediction of $A$ is 1. Intuitively
+speaking, if we have determined that a molecule belongs to a specific class (e.g., aromatic primary alcohol), it also
+belongs to the direct and indirect superclasses (e.g., primary alcohol, aromatic alcohol, alcohol).
+- (2) Next, we check for disjointness. This is not specified directly in ChEBI, but in an additional ChEBI module ([chebi-disjoints.owl](https://ftp.ebi.ac.uk/pub/databases/chebi/ontology/)).
+We have extracted these disjointness axioms into a CSV file and added some more disjointness axioms ourselves (see
+`data>disjoint_chebi.csv` and `data>disjoint_additional.csv`). If two classes $A$ and $B$ are disjoint and we predict
+both, we select one of them randomly and set the other to 0.
+- (3) Since the second step might have introduced new inconsistencies into the hierarchy, we repeat the first step, but
+with a small change. For a pair of classes $A \subseteq B$ with predictions $1$ and $0$, instead of setting $B$ to $1$,
+we now set $A$ to $0$. This has the advantage that we cannot introduce new disjointness-inconsistencies and don't have
+to repeat step 2.
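
To make the weighted voting described in the README addition above concrete, here is a minimal Python sketch of the per-class aggregation. It is illustrative only: `ModelVote` and `aggregate_class` are hypothetical names, not part of the chebifier API; the scores follow the formula quoted above (confidence · model_weight · trust).

```python
from dataclasses import dataclass


@dataclass
class ModelVote:
    """Hypothetical container for one model's vote on one class."""
    probability: float   # p_c^{m_i}: the model's predicted probability for class c
    model_weight: float  # per-model weight from the configuration file (default 1)
    trust: float         # per-class trust, e.g. 1 + F1 for wmv-f1, 1 for mv


def aggregate_class(votes: list[ModelVote], use_confidence: bool = True) -> int:
    """Return the ensemble decision (1 or 0) for a single class.

    Positive (p >= 0.5) and negative (p < 0.5) votes are summed separately;
    each vote contributes confidence * model_weight * trust.
    """
    positive = negative = 0.0
    for vote in votes:
        confidence = 2 * abs(vote.probability - 0.5) if use_confidence else 1.0
        score = confidence * vote.model_weight * vote.trust
        if vote.probability >= 0.5:
            positive += score
        else:
            negative += score
    return 1 if positive > negative else 0


# Worked example from the text: a weak positive vote (0.55) loses against
# a confident negative vote (0.1) when all weights and trusts are 1.
print(aggregate_class([ModelVote(0.55, 1.0, 1.0), ModelVote(0.10, 1.0, 1.0)]))  # -> 0
```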

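The three consistency-check steps can likewise be sketched in a few lines. Again, this is only an outline under stated assumptions, not the actual chebifier implementation: `enforce_consistency`, `subclass_pairs`, and `disjoint_pairs` are hypothetical names, and the subclass pairs are assumed to already include indirect (transitive) relations, as implied by step (1).

```python
import random


def enforce_consistency(pred: dict[str, int],
                        subclass_pairs: list[tuple[str, str]],
                        disjoint_pairs: list[tuple[str, str]]) -> dict[str, int]:
    """Post-process per-class 0/1 predictions as described above (sketch only)."""
    # Step 1: hierarchy correction - a positive subclass implies a positive superclass.
    for sub, sup in subclass_pairs:
        if pred.get(sub) == 1:
            pred[sup] = 1

    # Step 2: disjointness - if two disjoint classes are both predicted positive,
    # keep one of them at random and set the other to 0.
    for a, b in disjoint_pairs:
        if pred.get(a) == 1 and pred.get(b) == 1:
            pred[random.choice([a, b])] = 0

    # Step 3: repair the hierarchy again, this time downwards: if a superclass is 0,
    # its subclasses are set to 0 as well. This cannot introduce new disjointness
    # violations, so step 2 does not have to be repeated.
    for sub, sup in subclass_pairs:
        if pred.get(sup) == 0 and pred.get(sub) == 1:
            pred[sub] = 0

    return pred
```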
chebifier/cli.py

Lines changed: 25 additions & 37 deletions
@@ -2,54 +2,47 @@
 import yaml
 
 from .model_registry import ENSEMBLES
+from chebifier.ensemble.base_ensemble import BaseEnsemble
+from chebifier.ensemble.weighted_majority_ensemble import WMVwithPPVNPVEnsemble, WMVwithF1Ensemble
 
 
 @click.group()
 def cli():
     """Command line interface for Chebifier."""
     pass
 
+ENSEMBLES = {
+    "mv": BaseEnsemble,
+    "wmv-ppvnpv": WMVwithPPVNPVEnsemble,
+    "wmv-f1": WMVwithF1Ensemble
+}
 
 @cli.command()
-@click.argument("config_file", type=click.Path(exists=True))
-@click.option("--smiles", "-s", multiple=True, help="SMILES strings to predict")
-@click.option(
-    "--smiles-file",
-    "-f",
-    type=click.Path(exists=True),
-    help="File containing SMILES strings (one per line)",
-)
-@click.option(
-    "--output",
-    "-o",
-    type=click.Path(),
-    help="Output file to save predictions (optional)",
-)
-@click.option(
-    "--ensemble-type",
-    "-e",
-    type=click.Choice(ENSEMBLES.keys()),
-    default="mv",
-    help="Type of ensemble to use (default: Majority Voting)",
-)
-def predict(config_file, smiles, smiles_file, output, ensemble_type):
+@click.argument('config_file', type=click.Path(exists=True))
+@click.option('--smiles', '-s', multiple=True, help='SMILES strings to predict')
+@click.option('--smiles-file', '-f', type=click.Path(exists=True), help='File containing SMILES strings (one per line)')
+@click.option('--output', '-o', type=click.Path(), help='Output file to save predictions (optional)')
+@click.option('--ensemble-type', '-e', type=click.Choice(ENSEMBLES.keys()), default='mv', help='Type of ensemble to use (default: Majority Voting)')
+@click.option("--chebi-version", "-v", type=int, default=241, help="ChEBI version to use for checking consistency (default: 241)")
+@click.option("--use-confidence", "-c", is_flag=True, default=True, help="Weight predictions based on how 'confident' a model is in its prediction (default: True)")
+def predict(config_file, smiles, smiles_file, output, ensemble_type, chebi_version):
     """Predict ChEBI classes for SMILES strings using an ensemble model.
-
+
     CONFIG_FILE is the path to a YAML configuration file for the ensemble model.
     """
     # Load configuration from YAML file
-    with open(config_file, "r") as f:
+    with open(config_file, 'r') as f:
         config = yaml.safe_load(f)
-
+
     # Instantiate ensemble model
-    ensemble = ENSEMBLES[ensemble_type](config)
-
+    ensemble = ENSEMBLES[ensemble_type](config, chebi_version=chebi_version)
+
     # Collect SMILES strings from arguments and/or file
     smiles_list = list(smiles)
     if smiles_file:
-        with open(smiles_file, "r") as f:
+        with open(smiles_file, 'r') as f:
             smiles_list.extend([line.strip() for line in f if line.strip()])
-
+
     if not smiles_list:
         click.echo("No SMILES strings provided. Use --smiles or --smiles-file options.")
         return
@@ -60,13 +53,8 @@ def predict(config_file, smiles, smiles_file, output, ensemble_type):
     if output:
         # save as json
         import json
-
-        with open(output, "w") as f:
-            json.dump(
-                {smiles: pred for smiles, pred in zip(smiles_list, predictions)},
-                f,
-                indent=2,
-            )
+        with open(output, 'w') as f:
+            json.dump({smiles: pred for smiles, pred in zip(smiles_list, predictions)}, f, indent=2)
 
     else:
         # Print results
@@ -78,5 +66,5 @@ def predict(config_file, smiles, smiles_file, output, ensemble_type):
             click.echo(" No predictions")
 
 
-if __name__ == "__main__":
+if __name__ == '__main__':
     cli()
