
Commit abe9e2a

Merge branch 'dev' into feature/new-ensemble-models
2 parents bcf96f6 + 22c517c commit abe9e2a

File tree

7 files changed: +258 -71 lines changed


README.md

Lines changed: 15 additions & 2 deletions
@@ -78,8 +78,21 @@ The `classes_path` is the path to the dataset's `raw/classes.txt` file that cont
 
 ## Evaluation
 
-An example for evaluating a model trained on the ontology extension task is given in `tutorials/eval_model_basic.ipynb`.
-It takes in the finetuned model as input for performing the evaluation.
+You can evaluate a model trained on the ontology extension task in one of two ways:
+
+### 1. Using the Jupyter Notebook
+An example notebook is provided at `tutorials/eval_model_basic.ipynb`.
+- Load your finetuned model and run the evaluation cells to compute metrics on the test set.
+
+### 2. Using the Lightning CLI
+Alternatively, you can evaluate the model via the CLI:
+
+```bash
+python -m chebai test --trainer=configs/training/default_trainer.yml --trainer.devices=1 --trainer.num_nodes=1 --ckpt_path=[path-to-finetuned-model] --model=configs/model/electra.yml --model.test_metrics=configs/metrics/micro-macro-f1.yml --data=configs/data/chebi/chebi50.yml --data.init_args.batch_size=32 --data.init_args.num_workers=10 --data.init_args.chebi_version=[chebi-version] --model.pass_loss_kwargs=false --model.criterion=configs/loss/bce.yml --model.criterion.init_args.beta=0.99 --data.init_args.splits_file_path=[path-to-splits-file]
+```
+
+> **Note**: It is recommended to use `devices=1` and `num_nodes=1` during testing; multi-device settings use a `DistributedSampler`, which may replicate some samples to maintain equal batch sizes, so using a single device ensures that each sample or batch is evaluated exactly once.
+
 
 ## Cross-validation
 You can do inner k-fold cross-validation, i.e., train models on k train-validation splits that all use the same test
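To make the note about `DistributedSampler` padding concrete, here is a small standalone PyTorch snippet (illustrative only, not part of this commit or the chebai codebase): with five samples split across two devices, the sampler repeats an index so that both ranks receive equally sized shards, which is why a sample can be evaluated twice in a multi-device test run.

```python
import torch
from torch.utils.data import DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(5))  # 5 test samples
for rank in range(2):  # pretend we evaluate on 2 devices
    sampler = DistributedSampler(dataset, num_replicas=2, rank=rank, shuffle=False)
    print(f"rank {rank} sees indices: {list(sampler)}")
# rank 0 sees indices: [0, 2, 4]
# rank 1 sees indices: [1, 3, 0]  <- sample 0 is duplicated to pad the shard
```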

chebai/models/ffn.py

Lines changed: 19 additions & 1 deletion
@@ -14,10 +14,14 @@ def __init__(
         hidden_layers: List[int] = [
             1024,
         ],
+        use_adam_optimizer: bool = False,
         **kwargs,
     ):
         super().__init__(**kwargs)
 
+        self.use_adam_optimizer: bool = bool(use_adam_optimizer)
+        print(f"Using Adam optimizer: {self.use_adam_optimizer}")
+
         layers = []
         current_layer_input_size = self.input_dim
         for hidden_dim in hidden_layers:
@@ -26,7 +30,6 @@ def __init__(
             current_layer_input_size = hidden_dim
 
         layers.append(torch.nn.Linear(current_layer_input_size, self.out_dim))
-        layers.append(nn.Sigmoid())
         self.model = nn.Sequential(*layers)
 
     def _get_prediction_and_labels(self, data, labels, model_output):
@@ -63,6 +66,21 @@ def forward(self, data, **kwargs):
         x = data["features"]
         return {"logits": self.model(x)}
 
+    def configure_optimizers(self, **kwargs) -> torch.optim.Optimizer:
+        """
+        Configures the optimizers.
+
+        Args:
+            **kwargs: Additional keyword arguments.
+
+        Returns:
+            torch.optim.Optimizer: The optimizer.
+        """
+        if self.use_adam_optimizer:
+            return torch.optim.Adam(self.parameters(), **self.optimizer_kwargs)
+
+        return torch.optim.Adamax(self.parameters(), **self.optimizer_kwargs)
+
 
 class Residual(nn.Module):
     """

chebai/preprocessing/datasets/base.py

Lines changed: 38 additions & 7 deletions
@@ -1,6 +1,7 @@
 import os
 import random
 from abc import ABC, abstractmethod
+from pathlib import Path
 from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, Union
 
 import lightning as pl
@@ -76,6 +77,7 @@ def __init__(
         label_filter: Optional[int] = None,
         balance_after_filter: Optional[float] = None,
         num_workers: int = 1,
+        persistent_workers: bool = True,
         chebi_version: int = 200,
         inner_k_folds: int = -1,  # use inner cross-validation if > 1
         fold_index: Optional[int] = None,
@@ -99,6 +101,7 @@ def __init__(
         ), "Filter balancing requires a filter"
         self.balance_after_filter = balance_after_filter
         self.num_workers = num_workers
+        self.persistent_workers: bool = bool(persistent_workers)
         self.chebi_version = chebi_version
         assert type(inner_k_folds) is int
         self.inner_k_folds = inner_k_folds
@@ -363,7 +366,7 @@ def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader
             "train",
             shuffle=True,
             num_workers=self.num_workers,
-            persistent_workers=True,
+            persistent_workers=self.persistent_workers,
             **kwargs,
         )
 
@@ -382,7 +385,7 @@ def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]
             "validation",
             shuffle=False,
             num_workers=self.num_workers,
-            persistent_workers=True,
+            persistent_workers=self.persistent_workers,
             **kwargs,
         )
 
@@ -420,10 +423,17 @@ def prepare_data(self, *args, **kwargs) -> None:
 
         self._prepare_data_flag += 1
         self._perform_data_preparation(*args, **kwargs)
+        self._after_prepare_data(*args, **kwargs)
 
     def _perform_data_preparation(self, *args, **kwargs) -> None:
         raise NotImplementedError
 
+    def _after_prepare_data(self, *args, **kwargs) -> None:
+        """
+        Hook to perform additional pre-processing after pre-processed data is available.
+        """
+        ...
+
     def setup(self, *args, **kwargs) -> None:
         """
         Setup the data module.
@@ -466,14 +476,17 @@ def _set_processed_data_props(self):
         - self._num_of_labels: Number of target labels in the dataset.
         - self._feature_vector_size: Maximum feature vector length across all data points.
         """
-        data_pt = torch.load(
-            os.path.join(self.processed_dir, self.processed_file_names_dict["data"]),
-            weights_only=False,
+        pt_file_path = os.path.join(
+            self.processed_dir, self.processed_file_names_dict["data"]
         )
+        data_pt = torch.load(pt_file_path, weights_only=False)
 
         self._num_of_labels = len(data_pt[0]["labels"])
         self._feature_vector_size = max(len(d["features"]) for d in data_pt)
 
+        print(
+            f"Number of samples in encoded data ({pt_file_path}): {len(data_pt)} samples"
+        )
         print(f"Number of labels for loaded data: {self._num_of_labels}")
         print(f"Feature vector size: {self._feature_vector_size}")
 
@@ -747,6 +760,7 @@ def __init__(
         )
         self.apply_label_filter = apply_label_filter
         self.apply_id_filter = apply_id_filter
+        self._data_pkl_filename: str = "data.pkl"
 
     @staticmethod
     def _validate_splits_file_path(splits_file_path: Optional[str]) -> Optional[str]:
@@ -885,6 +899,21 @@ def save_processed(self, data: pd.DataFrame, filename: str) -> None:
         """
         pd.to_pickle(data, open(os.path.join(self.processed_dir_main, filename), "wb"))
 
+    def get_processed_pickled_df_file(self, filename: str) -> Optional[pd.DataFrame]:
+        """
+        Gets the processed dataset pickle file.
+
+        Args:
+            filename (str): The filename for the pickle file.
+
+        Returns:
+            pd.DataFrame: The processed dataset as a DataFrame.
+        """
+        file_path = Path(self.processed_dir_main) / filename
+        if file_path.exists():
+            return pd.read_pickle(file_path)
+        return None
+
     # ------------------------------ Phase: Setup data -----------------------------------
     def setup_processed(self) -> None:
         """
@@ -923,7 +952,9 @@ def _get_data_size(input_file_path: str) -> int:
             int: The size of the data.
         """
         with open(input_file_path, "rb") as f:
-            return len(pd.read_pickle(f))
+            df = pd.read_pickle(f)
+            print(f"Processed data size ({input_file_path}): {len(df)} rows")
+            return len(df)
 
     @abstractmethod
     def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]:
@@ -1260,7 +1291,7 @@ def processed_main_file_names_dict(self) -> dict:
             dict: A dictionary mapping dataset key to their respective file names.
             For example, {"data": "data.pkl"}.
         """
-        return {"data": "data.pkl"}
+        return {"data": self._data_pkl_filename}
 
     @property
     def raw_file_names(self) -> List[str]:
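The base-module changes above add a `persistent_workers` switch, an `_after_prepare_data` hook that runs right after the main preparation step, and a `get_processed_pickled_df_file` helper for reading back the processed pickle. The sketch below shows how a subclass could use this pattern; it is self-contained and illustrative only, and the class names and `data/processed` path are assumptions rather than the repository's actual class hierarchy:

```python
from pathlib import Path
from typing import Optional

import pandas as pd

class SketchDataModule:
    """Illustrative stand-in for the data module base class."""

    def __init__(self, processed_dir_main: str = "data/processed",
                 persistent_workers: bool = True):
        self.processed_dir_main = processed_dir_main
        # New flag: can be disabled, e.g. when debugging with few or no workers.
        self.persistent_workers = bool(persistent_workers)

    def prepare_data(self) -> None:
        self._perform_data_preparation()
        # New hook: called right after the main preparation step.
        self._after_prepare_data()

    def _perform_data_preparation(self) -> None:
        raise NotImplementedError

    def _after_prepare_data(self) -> None:
        """Optional override point for extra post-processing."""
        ...

    def get_processed_pickled_df_file(self, filename: str) -> Optional[pd.DataFrame]:
        # Returns the processed pickle as a DataFrame, or None if it is missing.
        file_path = Path(self.processed_dir_main) / filename
        return pd.read_pickle(file_path) if file_path.exists() else None

class MySketchModule(SketchDataModule):
    def _perform_data_preparation(self) -> None:
        print("preparing data ...")

    def _after_prepare_data(self) -> None:
        df = self.get_processed_pickled_df_file("data.pkl")
        print("rows in processed pickle:", 0 if df is None else len(df))

MySketchModule(persistent_workers=False).prepare_data()
```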
