Skip to content

Commit 13afc28

Browse files
committed
Merge branch 'dev' into fix/avoid_iterrows
2 parents 3785bb5 + 06a1869 commit 13afc28

File tree

16 files changed

+355
-59
lines changed

16 files changed

+355
-59
lines changed

.github/workflows/test.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ jobs:
99
strategy:
1010
fail-fast: false
1111
matrix:
12-
python-version: ["3.9", "3.10", "3.11", "3.12"]
12+
python-version: ["3.10", "3.11", "3.12"]
1313

1414
steps:
1515
- uses: actions/checkout@v4
@@ -24,7 +24,7 @@ jobs:
2424
python -m pip install --upgrade pip
2525
python -m pip install --upgrade pip setuptools wheel
2626
python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
27-
python -m pip install -e .
27+
python -m pip install -e .[dev]
2828
2929
- name: Display Python & Installed Packages
3030
run: |

README.md

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,21 @@ The `classes_path` is the path to the dataset's `raw/classes.txt` file that cont
7878

7979
## Evaluation
8080

81-
An example for evaluating a model trained on the ontology extension task is given in `tutorials/eval_model_basic.ipynb`.
82-
It takes in the finetuned model as input for performing the evaluation.
81+
You can evaluate a model trained on the ontology extension task in one of two ways:
82+
83+
### 1. Using the Jupyter Notebook
84+
An example notebook is provided at `tutorials/eval_model_basic.ipynb`.
85+
- Load your finetuned model and run the evaluation cells to compute metrics on the test set.
86+
87+
### 2. Using the Lightning CLI
88+
Alternatively, you can evaluate the model via the CLI:
89+
90+
```bash
91+
python -m chebai test --trainer=configs/training/default_trainer.yml --trainer.devices=1 --trainer.num_nodes=1 --ckpt_path=[path-to-finetuned-model] --model=configs/model/electra.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --data=configs/data/chebi/chebi50.yml --data.init_args.batch_size=32 --data.init_args.num_workers=10 --data.init_args.chebi_version=[chebi-version] --model.pass_loss_kwargs=false --model.criterion=configs/loss/bce.yml --model.criterion.init_args.beta=0.99 --data.init_args.splits_file_path=[path-to-splits-file]
92+
```
93+
94+
> **Note**: It is recommended to use `devices=1` and `num_nodes=1` during testing. Multi-device settings use a `DistributedSampler`, which may replicate some samples so that every device receives the same number of samples. Using a single device ensures that each sample is evaluated exactly once.
95+
8396

8497
## Cross-validation
8598
You can do inner k-fold cross-validation, i.e., train models on k train-validation splits that all use the same test

chebai/loggers/custom.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
from datetime import datetime
33
from typing import List, Literal, Optional, Union
44

5-
import wandb
65
from lightning.fabric.utilities.types import _PATH
76
from lightning.pytorch.callbacks import ModelCheckpoint
87
from lightning.pytorch.loggers import WandbLogger
@@ -105,6 +104,8 @@ def set_fold(self, fold: int) -> None:
105104
Args:
106105
fold (int): Cross-validation fold number.
107106
"""
107+
import wandb
108+
108109
if fold != self._fold:
109110
self._fold = fold
110111
# Start new experiment

chebai/loss/bce_weighted.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
from chebai.preprocessing.datasets.base import XYBaseDataModule
77
from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor
8-
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
98

109

1110
class BCEWeighted(torch.nn.BCEWithLogitsLoss):
@@ -27,6 +26,8 @@ def __init__(
2726
data_extractor: Optional[XYBaseDataModule] = None,
2827
**kwargs,
2928
):
29+
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
30+
3031
self.beta = beta
3132
if isinstance(data_extractor, LabeledUnlabeledMixed):
3233
data_extractor = data_extractor.labeled

chebai/loss/semantic.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,16 @@
22
import math
33
import os
44
import pickle
5-
from typing import List, Literal, Union
5+
from typing import TYPE_CHECKING, List, Literal, Union
66

77
import torch
88

99
from chebai.loss.bce_weighted import BCEWeighted
1010
from chebai.preprocessing.datasets.base import XYBaseDataModule
1111
from chebai.preprocessing.datasets.chebi import ChEBIOver100, _ChEBIDataExtractor
12-
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
12+
13+
if TYPE_CHECKING:
14+
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
1315

1416

1517
class ImplicationLoss(torch.nn.Module):
@@ -68,6 +70,8 @@ def __init__(
6870
multiply_with_base_loss: bool = True,
6971
no_grads: bool = False,
7072
):
73+
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
74+
7175
super().__init__()
7276
# automatically choose labeled subset for implication filter in case of mixed dataset
7377
if isinstance(data_extractor, LabeledUnlabeledMixed):
@@ -338,7 +342,7 @@ class DisjointLoss(ImplicationLoss):
338342
def __init__(
339343
self,
340344
path_to_disjointness: str,
341-
data_extractor: Union[_ChEBIDataExtractor, LabeledUnlabeledMixed],
345+
data_extractor: Union[_ChEBIDataExtractor, "LabeledUnlabeledMixed"],
342346
base_loss: torch.nn.Module = None,
343347
disjoint_loss_weight: float = 100,
344348
**kwargs,

chebai/models/ffn.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ class FFN(ChebaiBaseNet):
1111

1212
def __init__(
1313
self,
14-
input_size: int,
1514
hidden_layers: List[int] = [
1615
1024,
1716
],
@@ -20,7 +19,7 @@ def __init__(
2019
super().__init__(**kwargs)
2120

2221
layers = []
23-
current_layer_input_size = input_size
22+
current_layer_input_size = self.input_dim
2423
for hidden_dim in hidden_layers:
2524
layers.append(MLPBlock(current_layer_input_size, hidden_dim))
2625
layers.append(Residual(MLPBlock(hidden_dim, hidden_dim)))

chebai/preprocessing/datasets/base.py

Lines changed: 50 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,22 @@
11
import os
22
import random
33
from abc import ABC, abstractmethod
4-
from typing import Any, Dict, Generator, List, Optional, Tuple, Union
4+
from pathlib import Path
5+
from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union
56

67
import lightning as pl
7-
import networkx as nx
88
import pandas as pd
99
import torch
1010
import tqdm
11-
from iterstrat.ml_stratifiers import (
12-
MultilabelStratifiedKFold,
13-
MultilabelStratifiedShuffleSplit,
14-
)
1511
from lightning.pytorch.core.datamodule import LightningDataModule
1612
from lightning_utilities.core.rank_zero import rank_zero_info
17-
from sklearn.model_selection import StratifiedShuffleSplit
1813
from torch.utils.data import DataLoader
1914

2015
from chebai.preprocessing import reader as dr
2116

17+
if TYPE_CHECKING:
18+
import networkx as nx
19+
2220

2321
class XYBaseDataModule(LightningDataModule):
2422
"""
@@ -419,10 +417,17 @@ def prepare_data(self, *args, **kwargs) -> None:
419417

420418
self._prepare_data_flag += 1
421419
self._perform_data_preparation(*args, **kwargs)
420+
self._after_prepare_data(*args, **kwargs)
422421

423422
def _perform_data_preparation(self, *args, **kwargs) -> None:
424423
raise NotImplementedError
425424

425+
def _after_prepare_data(self, *args, **kwargs) -> None:
426+
"""
427+
Hook to perform additional pre-processing after pre-processed data is available.
428+
"""
429+
...
430+
426431
def setup(self, *args, **kwargs) -> None:
427432
"""
428433
Setup the data module.
@@ -464,14 +469,17 @@ def _set_processed_data_props(self):
464469
- self._num_of_labels: Number of target labels in the dataset.
465470
- self._feature_vector_size: Maximum feature vector length across all data points.
466471
"""
467-
data_pt = torch.load(
468-
os.path.join(self.processed_dir, self.processed_file_names_dict["data"]),
469-
weights_only=False,
472+
pt_file_path = os.path.join(
473+
self.processed_dir, self.processed_file_names_dict["data"]
470474
)
475+
data_pt = torch.load(pt_file_path, weights_only=False)
471476

472477
self._num_of_labels = len(data_pt[0]["labels"])
473478
self._feature_vector_size = max(len(d["features"]) for d in data_pt)
474479

480+
print(
481+
f"Number of samples in encoded data ({pt_file_path}): {len(data_pt)} samples"
482+
)
475483
print(f"Number of labels for loaded data: {self._num_of_labels}")
476484
print(f"Feature vector size: {self._feature_vector_size}")
477485

@@ -734,6 +742,7 @@ def __init__(
734742
self.splits_file_path = self._validate_splits_file_path(
735743
kwargs.get("splits_file_path", None)
736744
)
745+
self._data_pkl_filename: str = "data.pkl"
737746

738747
@staticmethod
739748
def _validate_splits_file_path(splits_file_path: Optional[str]) -> Optional[str]:
@@ -818,7 +827,7 @@ def _download_required_data(self) -> str:
818827
pass
819828

820829
@abstractmethod
821-
def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
830+
def _extract_class_hierarchy(self, data_path: str) -> "nx.DiGraph":
822831
"""
823832
Extracts the class hierarchy from the data.
824833
Constructs a directed graph (DiGraph) using NetworkX, where nodes are annotated with fields/terms from
@@ -833,7 +842,7 @@ def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph:
833842
pass
834843

835844
@abstractmethod
836-
def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame:
845+
def _graph_to_raw_dataset(self, graph: "nx.DiGraph") -> pd.DataFrame:
837846
"""
838847
Converts the graph to a raw dataset.
839848
Uses the graph created by `_extract_class_hierarchy` method to extract the
@@ -848,7 +857,7 @@ def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame:
848857
pass
849858

850859
@abstractmethod
851-
def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List:
860+
def select_classes(self, g: "nx.DiGraph", *args, **kwargs) -> List:
852861
"""
853862
Selects classes from the dataset based on a specified criteria.
854863
@@ -872,6 +881,21 @@ def save_processed(self, data: pd.DataFrame, filename: str) -> None:
872881
"""
873882
pd.to_pickle(data, open(os.path.join(self.processed_dir_main, filename), "wb"))
874883

884+
def get_processed_pickled_df_file(self, filename: str) -> Optional[pd.DataFrame]:
885+
"""
886+
Loads the processed dataset from its pickle file, if the file exists.
887+
888+
Args:
889+
filename (str): The filename for the pickle file.
890+
891+
Returns:
892+
Optional[pd.DataFrame]: The processed dataset as a DataFrame, or None if the file does not exist.
893+
"""
894+
file_path = Path(self.processed_dir_main) / filename
895+
if file_path.exists():
896+
return pd.read_pickle(file_path)
897+
return None
898+
875899
# ------------------------------ Phase: Setup data -----------------------------------
876900
def setup_processed(self) -> None:
877901
"""
@@ -910,7 +934,9 @@ def _get_data_size(input_file_path: str) -> int:
910934
int: The size of the data.
911935
"""
912936
with open(input_file_path, "rb") as f:
913-
return len(pd.read_pickle(f))
937+
df = pd.read_pickle(f)
938+
print(f"Processed data size ({input_file_path}): {len(df)} rows")
939+
return len(df)
914940

915941
@abstractmethod
916942
def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]:
@@ -1023,6 +1049,9 @@ def get_test_split(
10231049
Raises:
10241050
ValueError: If the DataFrame does not contain a column named "labels".
10251051
"""
1052+
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
1053+
from sklearn.model_selection import StratifiedShuffleSplit
1054+
10261055
print("Get test data split")
10271056

10281057
labels_list = df["labels"].tolist()
@@ -1060,6 +1089,12 @@ def get_train_val_splits_given_test(
10601089
and validation DataFrames. The keys are the names of the train and validation sets, and the values
10611090
are the corresponding DataFrames.
10621091
"""
1092+
from iterstrat.ml_stratifiers import (
1093+
MultilabelStratifiedKFold,
1094+
MultilabelStratifiedShuffleSplit,
1095+
)
1096+
from sklearn.model_selection import StratifiedShuffleSplit
1097+
10631098
print("Split dataset into train / val with given test set")
10641099

10651100
test_ids = test_df["ident"].tolist()
@@ -1217,7 +1252,7 @@ def processed_main_file_names_dict(self) -> dict:
12171252
dict: A dictionary mapping dataset key to their respective file names.
12181253
For example, {"data": "data.pkl"}.
12191254
"""
1220-
return {"data": "data.pkl"}
1255+
return {"data": self._data_pkl_filename}
12211256

12221257
@property
12231258
def raw_file_names(self) -> List[str]:

0 commit comments

Comments
 (0)