Merged
49 commits
a91746f  add documentation for pubchem kmeans, update tokens.txt (Sep 18, 2024)
30d194c  update SMILES tokens (Sep 18, 2024)
af54a11  replace tnorm by fuzzy implication names (Sep 19, 2024)
31c8107  fix argument linking (Sep 19, 2024)
04226a5  fix semantic loss (Sep 19, 2024)
094ade2  add kleene-dienes implication and sigmoidal implications (Sep 20, 2024)
74a4e12  fix reichenbach abbreviation rb->rc (Sep 20, 2024)
f751bc5  fix fuzzy loss evaluation (Sep 24, 2024)
4e4a8fb  add epoch-dependent weighting of semantic terms (Sep 25, 2024)
5fa3ac5  fix tensor handling (Sep 25, 2024)
8acf6ba  fix dynamic loss weights, log weighted and unweighted loss components (Sep 25, 2024)
72b0011  no fuzzy loss for epoch<=10 (Sep 26, 2024)
5d80272  actually set fuzzy loss to 0 for epoch<=10 (Sep 26, 2024)
e57aa8d  fix bce loss (Sep 27, 2024)
94039c0  remove skipping first 10 epochs (Sep 30, 2024)
bdef653  fix fuzzy loss (now passing the weighted loss components) (Oct 1, 2024)
7f1c468  add epoch to analyse_sem output (Oct 1, 2024)
0264c85  add right-aggregated macro-FNR (Oct 1, 2024)
cfda8c6  download ckpt without returning model (return ckpt path instead), imp… (Oct 2, 2024)
2a7a10f  fix sigmoidal implication (Oct 4, 2024)
a85498c  add goedel implication (Oct 4, 2024)
eef32d2  fix implication loss signature (Oct 7, 2024)
6f6c6a0  fix goedel loss (Oct 7, 2024)
d83957c  add error messages (Oct 7, 2024)
e4b7076  fix sigmoidal implication (Oct 8, 2024)
9d619a7  add ap metric, results by class to analyse_sem (Oct 9, 2024)
cc1bba7  fix goedel loss (Oct 9, 2024)
235abdb  fix pos scalar typehint (Oct 10, 2024)
e892967  add epsilon to consequent (for balanced loss with k < 1) (Oct 10, 2024)
5fed371  merge dev into feature-fuzzy-loss (Oct 10, 2024)
f34a5ad  add parameters to epoch-dependent weighting (Oct 10, 2024)
436c897  disable strict checkpoint loading (Oct 10, 2024)
e7c9ff8  fix checkpoint loading (Oct 10, 2024)
7f7695d  add raw file names (temporary fix), add pubchem data to fuzzy eval, (Oct 14, 2024)
33c2d64  efficiency, minor fixes, changed paths (Oct 17, 2024)
0e32978  add elementwise multiplicative fuzzy loss (Oct 21, 2024)
2037c20  clean up analyse_sem.py (Oct 23, 2024)
a16a065  fix fuzzy loss mean aggregation (Oct 23, 2024)
70d0f29  make fuzzy loss implementation more efficient (Oct 23, 2024)
93a4aae  fix device (Oct 23, 2024)
fbd2c65  add max aggregation for fuzzy loss (Oct 24, 2024)
acb05d7  adapt evaluation to new fuzzy loss (Oct 25, 2024)
ebda76c  add binary implication (Oct 25, 2024)
71a3f30  fix disjointness, binary implication (Oct 28, 2024)
4b354d3  add sigmoid (Oct 28, 2024)
45c54f4  add log- and mean aggregation (Nov 4, 2024)
62a49eb  add optional detach from gradients (Nov 4, 2024)
7639179  Merge branch 'dev' into feature-fuzzy-loss (sfluegel05, Dec 19, 2024)
e7bd80a  fix: ignore empty list in evaluation (Dec 19, 2024)
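
The commits above repeatedly reference fuzzy implication operators (Kleene-Dienes, Reichenbach, Gödel) and their aggregation into a semantic loss. For orientation, a minimal sketch of the textbook definitions these names refer to; the function names and signatures are illustrative only and do not reproduce the repository's implementation:

import torch

def reichenbach(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Reichenbach implication: I(a, b) = 1 - a + a * b
    return 1 - a + a * b

def kleene_dienes(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Kleene-Dienes implication: I(a, b) = max(1 - a, b)
    return torch.maximum(1 - a, b)

def goedel(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Goedel implication: I(a, b) = 1 if a <= b, else b
    return torch.where(a <= b, torch.ones_like(b), b)

# A semantic loss term for an axiom "A implies B" penalises its violation,
# e.g. loss = 1 - I(p_A, p_B), aggregated over all axiom pairs; the commits
# above experiment with mean, max and log aggregation of these terms.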
4 changes: 4 additions & 0 deletions chebai/cli.py
@@ -50,6 +50,10 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser):
         parser.link_arguments(
             "data", "model.init_args.criterion.init_args.data_extractor"
         )
+        parser.link_arguments(
+            "data.init_args.chebi_version",
+            "model.init_args.criterion.init_args.data_extractor.init_args.chebi_version",
+        )
 
     @staticmethod
     def subcommands() -> Dict[str, Set[str]]:
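
The added link ensures the ChEBI version configured on the data module is propagated to the criterion's data extractor, so it only needs to be set once in the config. As a self-contained illustration of how LightningCLI argument linking works (the import path assumes a recent Lightning release; the CLI class name is invented for the example):

from lightning.pytorch.cli import LightningArgumentParser, LightningCLI

class ExampleCLI(LightningCLI):
    def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
        # Copy the user-supplied --data.init_args.chebi_version into the
        # model's nested criterion config, so both components always agree
        # on the ontology version.
        parser.link_arguments(
            "data.init_args.chebi_version",
            "model.init_args.criterion.init_args.data_extractor.init_args.chebi_version",
        )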
19 changes: 15 additions & 4 deletions chebai/loss/bce_weighted.py
@@ -5,14 +5,16 @@
 import torch
 
 from chebai.preprocessing.datasets.base import XYBaseDataModule
+from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor
 from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
 
 
 class BCEWeighted(torch.nn.BCEWithLogitsLoss):
     """
     BCEWithLogitsLoss with weights automatically computed according to the beta parameter.
+    If beta is None or data_extractor is None, the loss is unweighted.
 
-    This class computes weights based on the formula from the paper:
+    This class computes weights based on the formula from the paper by Cui et al. (2019):
     https://openaccess.thecvf.com/content_CVPR_2019/papers/Cui_Class-Balanced_Loss_Based_on_Effective_Number_of_Samples_CVPR_2019_paper.pdf
 
     Args:
@@ -24,13 +26,17 @@ def __init__(
         self,
         beta: Optional[float] = None,
         data_extractor: Optional[XYBaseDataModule] = None,
+        **kwargs,
     ):
         self.beta = beta
         if isinstance(data_extractor, LabeledUnlabeledMixed):
             data_extractor = data_extractor.labeled
         self.data_extractor = data_extractor
-
-        super().__init__()
+        assert (
+            isinstance(self.data_extractor, _ChEBIDataExtractor)
+            or self.data_extractor is None
+        )
+        super().__init__(**kwargs)
 
     def set_pos_weight(self, input: torch.Tensor) -> None:
         """
@@ -50,6 +56,9 @@ def set_pos_weight(self, input: torch.Tensor) -> None:
             )
             and self.pos_weight is None
         ):
+            print(
+                f"Computing loss-weights based on v{self.data_extractor.chebi_version} dataset (beta={self.beta})"
+            )
             complete_data = pd.concat(
                 [
                     pd.read_pickle(
@@ -75,7 +84,9 @@ def set_pos_weight(self, input: torch.Tensor) -> None:
                 [w / mean for w in weights], device=input.device
             )
 
-    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+    def forward(
+        self, input: torch.Tensor, target: torch.Tensor, **kwargs
+    ) -> torch.Tensor:
         """
         Forward pass for the loss calculation.
 
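
For reference, the weighting scheme named in the docstring above: Cui et al. (2019) define the effective number of samples of a class with n positive examples as E_n = (1 - beta^n) / (1 - beta) and weight each class by its inverse, normalised to mean 1 (matching the [w / mean for w in weights] line in the diff). A condensed sketch of that computation; the function below is hypothetical, while the actual set_pos_weight derives the per-class counts from the pickled dataset splits:

import torch

def class_balanced_weights(pos_counts: torch.Tensor, beta: float) -> torch.Tensor:
    # Effective number of samples per class: E_n = (1 - beta^n) / (1 - beta)
    effective_num = (1.0 - beta**pos_counts) / (1.0 - beta)
    # Weight each class by 1 / E_n, then normalise so the mean weight is 1
    weights = 1.0 / effective_num
    return weights / weights.mean()

# Example: with beta close to 1, rare classes get much larger weights
pos_weight = class_balanced_weights(torch.tensor([10.0, 1000.0, 50000.0]), beta=0.99)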