Commit 8e56851

merge from dev
2 parents 0a39749 + 21e4378 commit 8e56851

10 files changed: +96 -46 lines changed

README.md

Lines changed: 60 additions & 26 deletions
@@ -1,43 +1,77 @@

+# ChEB-AI Graph
+
+Graph-based models for molecular property prediction and ontology classification, built on top of the [`python-chebai`](https://github.com/ChEB-AI/python-chebai) codebase.
+
+

 ## Installation

-Some requirements may not be installed successfully automatically.
-To install the `torch-` libraries, use
+To install this repository, download it and run

-`pip install torch-${lib} -f https://data.pyg.org/whl/torch-2.1.0+${CUDA}.html`
+```bash
+pip install .
+```

-where `${lib}` is either `scatter`, `geometric`, `sparse` or `cluster`, and
-`${CUDA}` is either `cpu`, `cu118` or `cu121` (depending on your system, see e.g.
-[torch-geometric docs](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html))
+or install it directly with
+```bash
+pip install git+https://github.com/ChEB-AI/python-chebai-graph.git
+```

+The dependencies `torch`, `torch_geometric` and `torch-sparse` cannot be installed automatically.

-## Commands
+Use the following command:

-For training, config files from the `python-chebai` and `python-chebai-graph` repositories can be combined. This requires that you download the [source code of python-chebai](https://github.com/ChEB-AI/python-chebai). Make sure that you are in the right folder and know the relative path to the other repository.
+```bash
+pip install torch torch_scatter torch_geometric -f https://data.pyg.org/whl/torch-${TORCH}+${CUDA}.html
+```

-We recommend the following setup:
+Replace:
+- `${TORCH}` with a PyTorch version (e.g., `2.6.0`; for later versions, check first if they are compatible with torch_scatter and torch_geometric)
+- `${CUDA}` with e.g. `cpu`, `cu118`, or `cu121` depending on your system and CUDA version

-my_projects
-    python-chebai
-        chebai
-        configs
-        data
-        ...
-    python-chebai-graph
-        chebai_graph
-        configs
-        ...
+If you already have `torch` installed, make sure that `torch_scatter` and `torch_geometric` are compatible with your
+PyTorch version and are installed with the same CUDA version.

-If you run the command from the `python-chebai` directory, you can use the same data for both chebai- and chebai-graph-models (e.g., Transformers and GNNs).
-Then you have to use `{path-to-chebai} -> .` and `{path-to-chebai-graph} -> ../python-chebai-graph`.
+For a full list of currently available PyTorch versions and CUDA compatibility, please refer to libraries' official documentation:
+- [torch](https://pytorch.org/get-started/locally/)
+- [torch_geometric](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html#installation)
+- [torch-scatter](https://github.com/rusty1s/pytorch_scatter)

-Pretraining on a atom / bond masking task with PubChem data (feature-branch):
-```
-python3 -m chebai fit --model={path-to-chebai-graph}/configs/model/gnn_resgated_pretrain.yml --data={path-to-chebai-graph}/configs/data/pubchem_graph.yml --trainer={path-to-chebai}/configs/training/pretraining_trainer.yml
+_Note for developers_: If you want to install the package in editable mode, use the following command instead:
+
+```bash
+pip install -e .
 ```

-Training on the ontology prediction task (here for ChEBI50, v231, 200 epochs)
+## Recommended Folder Structure
+
+ChEB-AI Graph is not a standalone library. Instead, it provides additional models and datasets for [`python-chebai`](https://github.com/ChEB-AI/python-chebai).
+The training relies on config files that are located either in `python-chebai` or in this repository.
+
+Therefore, for training, we recommend to clone both repositories into a common parent directory. For instance, your project can look like this:
+
 ```
-python3 -m chebai fit --trainer={path-to-chebai}/configs/training/default_trainer.yml --trainer.callbacks={path-to-chebai}/configs/training/default_callbacks.yml --model={path-to-chebai-graph}/configs/model/gnn_res_gated.yml --model.train_metrics={path-to-chebai}/configs/metrics/micro-macro-f1.yml --model.test_metrics={path-to-chebai}/configs/metrics/micro-macro-f1.yml --model.val_metrics={path-to-chebai}/configs/metrics/micro-macro-f1.yml --data={path-to-chebai-graph}/configs/data/chebi50_graph_properties.yml --model.criterion=c{path-to-chebai}/onfigs/loss/bce.yml --data.init_args.batch_size=40 --trainer.logger.init_args.name=chebi50_bce_unweighted_resgatedgraph --data.init_args.num_workers=12 --model.pass_loss_kwargs=false --data.init_args.chebi_version=231 --trainer.min_epochs=200 --trainer.max_epochs=200
+my_projects/
+├── python-chebai/
+│   ├── chebai/
+│   ├── configs/
+│   └── ...
+└── python-chebai-graph/
+    ├── chebai_graph/
+    ├── configs/
+    └── ...
+```
+
+## Training & Pretraining
+
+### Ontology Prediction
+
+
+This example command trains a Residual Gated Graph Convolutional Network on the ChEBI50 dataset (see [wiki](https://github.com/ChEB-AI/python-chebai/wiki/Data-Management)).
+The dataset has a customizable list of properties for atoms, bonds and molecules that are added to the graph.
+The list can be found in the `configs/data/chebi50_graph_properties.yml` file.
+
+```bash
+python -m chebai fit --trainer=configs/training/default_trainer.yml --trainer.logger=configs/training/csv_logger.yml --model=../python-chebai-graph/configs/model/gnn_res_gated.yml --model.train_metrics=configs/metrics/micro-macro-f1.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --model.val_metrics=configs/metrics/micro-macro-f1.yml --data=../python-chebai-graph/configs/data/chebi50_graph_properties.yml --data.init_args.batch_size=128 --trainer.accumulate_grad_batches=4 --data.init_args.num_workers=10 --model.pass_loss_kwargs=false --data.init_args.chebi_version=241 --trainer.min_epochs=200 --trainer.max_epochs=200 --model.criterion=configs/loss/bce.yml
 ```
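
Note on the new install command in the README: the `${TORCH}` and `${CUDA}` placeholders follow the `<version>+<build>` form that PyTorch wheels typically report in `torch.__version__`. The following is only a sketch of how the wheel index URL could be derived from such a string. The helper name `pyg_wheel_index` is hypothetical, and treating a version without a `+` suffix as a CPU build is an assumption that the README does not state.

```python
def pyg_wheel_index(torch_version: str) -> str:
    """Build the wheel index URL used by the README's `pip install ... -f` command.

    `torch_version` is a string such as "2.6.0+cu121".
    """
    base, _, build = torch_version.partition("+")
    # Assumption: no "+<build>" suffix means a CPU-only build.
    cuda = build or "cpu"
    return f"https://data.pyg.org/whl/torch-{base}+{cuda}.html"


print(pyg_wheel_index("2.6.0+cu121"))
```

The README's compatibility caveat still applies: whether `torch_scatter` and `torch_geometric` wheels exist for a given combination has to be checked against the linked documentation.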

chebai_graph/preprocessing/bin/Aromaticity/indices.txt

Whitespace-only changes.

chebai_graph/preprocessing/bin/AtomType/indices.txt

Whitespace-only changes.

chebai_graph/preprocessing/bin/BondType/indices.txt

Whitespace-only changes.

chebai_graph/preprocessing/bin/FormalCharge/indices_one_hot.txt

Whitespace-only changes.

chebai_graph/preprocessing/bin/MoleculeNumRings/indices_one_hot.txt

Whitespace-only changes.

chebai_graph/preprocessing/datasets/chebi.py

Lines changed: 7 additions & 4 deletions
@@ -139,11 +139,14 @@ def get_property_path(self, property: MolecularProperty):
             f"{property.name}_{property.encoder.name}.pt",
         )

-    def setup(self, **kwargs):
-        super().setup(keep_reader=True, **kwargs)
-        self._setup_properties()
+    def _after_setup(self, **kwargs):
+        """
+        Finalize the setup process after ensuring the processed data is available.

-        self.reader.on_finish()
+        This method performs post-setup tasks like finalizing the reader and setting internal properties.
+        """
+        self._setup_properties()
+        super()._after_setup(**kwargs)

     def _merge_props_into_base(self, row):
         geom_data = row["features"]
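
Note on the hunk above: overriding `setup` is replaced by overriding an `_after_setup` hook. The parent class lives in `python-chebai` and is not part of this diff, so the sketch below uses hypothetical stand-in classes. It assumes, based on the new docstring, that the base `setup` calls `_after_setup` once the processed data is available and that the parent hook finalizes the reader.

```python
class BaseDataModule:
    """Hypothetical stand-in for the parent data module in python-chebai."""

    def __init__(self):
        self.events = []

    def setup(self, **kwargs):
        # The base class owns the setup flow and calls the hook at the end.
        self.events.append("prepare processed data")
        self._after_setup(**kwargs)

    def _after_setup(self, **kwargs):
        # Assumed parent behaviour: finalize the reader.
        self.events.append("finalize reader")


class GraphDataModule(BaseDataModule):
    def _after_setup(self, **kwargs):
        # Subclass-specific work first, then the parent's finalization.
        self._setup_properties()
        super()._after_setup(**kwargs)

    def _setup_properties(self):
        self.events.append("set up properties")


dm = GraphDataModule()
dm.setup()
print(dm.events)
```

Under these assumptions, the subclass no longer has to pass `keep_reader=True` or call `self.reader.on_finish()` itself, because the ordering is fixed by the base class.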

chebai_graph/preprocessing/property_encoder.py

Lines changed: 27 additions & 11 deletions
@@ -37,11 +37,13 @@ class IndexEncoder(PropertyEncoder):
     def __init__(self, property, indices_dir=None, **kwargs):
         super().__init__(property, **kwargs)
         if indices_dir is None:
-            indices_dir = os.path.dirname(__file__)
+            indices_dir = os.path.dirname(inspect.getfile(self.__class__))
         self.dirname = indices_dir
         # load already existing cache
         with open(self.index_path, "r") as pk:
-            self.cache = [x.strip() for x in pk]
+            self.cache: dict[str, int] = {
+                token.strip(): idx for idx, token in enumerate(pk)
+            }
         self.index_length_start = len(self.cache)
         self.offset = 0

@@ -65,19 +67,33 @@ def index_path(self):

     def on_finish(self):
         """Save cache"""
-        with open(self.index_path, "w") as pk:
-            new_length = len(self.cache) - self.index_length_start
-            pk.writelines([f"{c}\n" for c in self.cache])
-            print(
-                f"saved index of property {self.property.name} to {self.index_path}, "
-                f"index length: {len(self.cache)} (new: {new_length})"
-            )
+        total_tokens = len(self.cache)
+        if total_tokens > self.index_length_start:
+            print("New tokens added to the cache, Saving them to index token file.....")
+
+            assert sys.version_info >= (
+                3,
+                7,
+            ), "This code requires Python 3.7 or higher."
+            # For python 3.7+, the standard dict type preserves insertion order, and is iterated over in same order
+            # https://docs.python.org/3/whatsnew/3.7.html#summary-release-highlights
+            # https://mail.python.org/pipermail/python-dev/2017-December/151283.html
+            new_tokens = list(islice(self.cache, self.index_length_start, total_tokens))
+
+            with open(self.index_path, "a") as pk:
+                pk.writelines([f"{c}\n" for c in new_tokens])
+                print(
+                    f"New {len(new_tokens)} tokens append to index of property {self.property.name} to {self.index_path}..."
+                )
+                print(
+                    f"Now, the total length of the index of property {self.property.name} is {total_tokens}"
+                )

     def encode(self, token):
         """Returns a unique number for each token, automatically adds new tokens to the cache."""
         if not str(token) in self.cache:
-            self.cache.append(str(token))
-        return torch.tensor([self.cache.index(str(token)) + self.offset])
+            self.cache[(str(token))] = len(self.cache)
+        return torch.tensor([self.cache[str(token)] + self.offset])


 class OneHotEncoder(IndexEncoder):
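
Note on the hunks above: the cache changes from a list to a `dict` mapping each token to its index, and `on_finish` now appends only the tokens added since loading instead of rewriting the whole file. The following torch-free sketch illustrates that idea. The class name `TokenIndex` and its attributes are hypothetical stand-ins and do not reproduce the real `IndexEncoder` API.

```python
import os
import tempfile
from itertools import islice


class TokenIndex:
    """Hypothetical, torch-free stand-in for the IndexEncoder cache."""

    def __init__(self, path):
        self.path = path
        with open(path, "r") as f:
            # token -> position; dicts keep insertion order (Python 3.7+)
            self.cache = {line.strip(): idx for idx, line in enumerate(f)}
        self.start_len = len(self.cache)

    def encode(self, token):
        key = str(token)
        if key not in self.cache:
            self.cache[key] = len(self.cache)  # next free index
        return self.cache[key]  # dict lookup instead of list.index

    def on_finish(self):
        # Only tokens added since loading are appended to the file.
        new_tokens = list(islice(self.cache, self.start_len, None))
        if new_tokens:
            with open(self.path, "a") as f:
                f.writelines(f"{t}\n" for t in new_tokens)
        return new_tokens


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "indices.txt")
    with open(path, "w") as f:
        f.write("C\nN\n")
    index = TokenIndex(path)
    codes = [index.encode(t) for t in ["N", "O", "C", "O"]]
    added = index.on_finish()
    with open(path) as f:
        saved = f.read().splitlines()
print(codes, added, saved)
```

Because existing lines are never rewritten, previously assigned indices stay stable across runs. This relies on the insertion-order guarantee that the commit asserts for Python 3.7+.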

chebai_graph/preprocessing/reader.py

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@
 from chebai_graph.preprocessing.collate import GraphCollator


-class GraphPropertyReader(dr.ChemDataReader):
+class GraphPropertyReader(dr.DataReader):
     COLLATOR = GraphCollator

     def __init__(

pyproject.toml

Lines changed: 1 addition & 4 deletions
@@ -6,10 +6,7 @@ authors = [
     { name = "Martin Glauer", email = "[email protected]" }
 ]
 dependencies = [
-    "torch_geometric",
-    "torch-scatter",
-    "torch-sparse",
-    "torch-cluster",
+    "chebai",
     "descriptastorus"
 ]
