Commit 3a67087

Merge branch 'dev' into out_dim_dynamic
2 parents ed50c8d + ff03d6f

14 files changed (+999, -474 lines)

.github/workflows/test.yml

Lines changed: 14 additions & 3 deletions

@@ -9,19 +9,30 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.9", "3.10", "3.11"]
+        python-version: ["3.9", "3.10", "3.11", "3.12"]
 
     steps:
       - uses: actions/checkout@v4
+
       - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@v5
         with:
           python-version: ${{ matrix.python-version }}
+
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
           python -m pip install --upgrade pip setuptools wheel
           python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
           python -m pip install -e .
-      - name: Display Python version
-        run: python -m unittest discover -s tests/unit
+
+      - name: Display Python & Installed Packages
+        run: |
+          python --version
+          pip freeze
+
+      - name: Run Unit Tests
+        run: python -m unittest discover -s tests/unit -v
+        env:
+          ACTIONS_STEP_DEBUG: true # Enable debug logs
+          ACTIONS_RUNNER_DEBUG: true # Additional debug logs from Github Actions itself
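For reference, the renamed "Run Unit Tests" step can be reproduced locally. The sketch below is not part of the commit; it is a rough equivalent that uses unittest's discovery API against the same tests/unit directory, with verbosity=2 mirroring the -v flag added here.

# Rough local equivalent of the "Run Unit Tests" workflow step (sketch only).
# Discovers tests under tests/unit, the directory the workflow points at.
import sys
import unittest

if __name__ == "__main__":
    suite = unittest.TestLoader().discover(start_dir="tests/unit")
    result = unittest.TextTestRunner(verbosity=2).run(suite)
    sys.exit(0 if result.wasSuccessful() else 1)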

chebai/cli.py

Lines changed: 0 additions & 3 deletions

@@ -58,9 +58,6 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser):
         parser.link_arguments(
             "data.num_of_labels", "trainer.callbacks.init_args.num_labels"
         )
-        parser.link_arguments(
-            "data", "model.init_args.criterion.init_args.data_extractor"
-        )
 
     @staticmethod
     def subcommands() -> Dict[str, Set[str]]:
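The deleted call had auto-filled the criterion's data_extractor from the datamodule at parse time; after this commit it is presumably supplied explicitly (e.g., via the config), while the loss itself guards the value (see bce_weighted.py below). As an illustration only, and assuming the lightning.pytorch.cli import path and a hypothetical SketchCLI class rather than this repo's actual CLI, LightningCLI argument linking looks like this, with only the label-count link remaining:

# Illustrative sketch of LightningCLI argument linking (not chebai's real class).
from lightning.pytorch.cli import LightningArgumentParser, LightningCLI


class SketchCLI(LightningCLI):
    def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
        # Copy the datamodule's num_of_labels into the callback configuration
        # at parse time instead of requiring it twice in the config file.
        parser.link_arguments(
            "data.num_of_labels", "trainer.callbacks.init_args.num_labels"
        )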

chebai/loss/bce_weighted.py

Lines changed: 15 additions & 4 deletions

@@ -5,14 +5,16 @@
 import torch
 
 from chebai.preprocessing.datasets.base import XYBaseDataModule
+from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor
 from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
 
 
 class BCEWeighted(torch.nn.BCEWithLogitsLoss):
     """
     BCEWithLogitsLoss with weights automatically computed according to the beta parameter.
+    If beta is None or data_extractor is None, the loss is unweighted.
 
-    This class computes weights based on the formula from the paper:
+    This class computes weights based on the formula from the paper by Cui et al. (2019):
     https://openaccess.thecvf.com/content_CVPR_2019/papers/Cui_Class-Balanced_Loss_Based_on_Effective_Number_of_Samples_CVPR_2019_paper.pdf
 
     Args:

@@ -24,13 +26,17 @@ def __init__(
         self,
         beta: Optional[float] = None,
         data_extractor: Optional[XYBaseDataModule] = None,
+        **kwargs,
     ):
         self.beta = beta
         if isinstance(data_extractor, LabeledUnlabeledMixed):
             data_extractor = data_extractor.labeled
         self.data_extractor = data_extractor
-
-        super().__init__()
+        assert (
+            isinstance(self.data_extractor, _ChEBIDataExtractor)
+            or self.data_extractor is None
+        )
+        super().__init__(**kwargs)
 
     def set_pos_weight(self, input: torch.Tensor) -> None:
         """

@@ -50,6 +56,9 @@ def set_pos_weight(self, input: torch.Tensor) -> None:
             )
             and self.pos_weight is None
         ):
+            print(
+                f"Computing loss-weights based on v{self.data_extractor.chebi_version} dataset (beta={self.beta})"
+            )
             complete_data = pd.concat(
                 [
                     pd.read_pickle(

@@ -75,7 +84,9 @@ def set_pos_weight(self, input: torch.Tensor) -> None:
                 [w / mean for w in weights], device=input.device
             )
 
-    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self, input: torch.Tensor, target: torch.Tensor, **kwargs
+    ) -> torch.Tensor:
         """
         Forward pass for the loss calculation.
 
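The updated docstring points to Cui et al. (2019). The sketch below illustrates that weighting scheme in isolation; it is not BCEWeighted's actual code path (which derives counts from the pickled ChEBI splits), and the per-class counts are made up. The weight of each class is the inverse of its effective number of samples, (1 - beta) / (1 - beta**n), normalized to mean 1, matching the `[w / mean for w in weights]` line visible in the diff.

# Illustrative sketch of class-balanced weighting (Cui et al., 2019), assumed
# counts; BCEWeighted itself computes counts from the ChEBI dataset pickles.
import torch


def class_balanced_pos_weight(counts: torch.Tensor, beta: float) -> torch.Tensor:
    # Effective number of samples per class: (1 - beta**n) / (1 - beta).
    effective_num = (1.0 - beta ** counts) / (1.0 - beta)
    weights = 1.0 / effective_num
    # Normalize so the average weight is 1, as in the diff's w / mean.
    return weights / weights.mean()


counts = torch.tensor([12.0, 480.0, 35_000.0])  # positives per class (hypothetical)
pos_weight = class_balanced_pos_weight(counts, beta=0.99)
loss_fn = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)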