Skip to content

Commit 3851ee8

Browse files
author
sfluegel
committed
add argument linking for data_extractor, propagation from semantic loss to base loss
1 parent 8efa2c1 commit 3851ee8

File tree

6 files changed

+29
-10
lines changed

6 files changed

+29
-10
lines changed

chebai/cli.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser):
1919
parser.link_arguments(
2020
"model.init_args.out_dim", "trainer.callbacks.init_args.num_labels"
2121
)
22+
parser.link_arguments(
23+
"data", "model.init_args.criterion.init_args.data_extractor"
24+
)
2225

2326
@staticmethod
2427
def subcommands() -> Dict[str, Set[str]]:

chebai/loss/bce_weighted.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import torch
2-
from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor
2+
from chebai.preprocessing.datasets.base import XYBaseDataModule
3+
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
34
import pandas as pd
45
import os
56
import pickle
@@ -10,9 +11,16 @@ class BCEWeighted(torch.nn.BCEWithLogitsLoss):
1011
https://openaccess.thecvf.com/content_CVPR_2019/papers/Cui_Class-Balanced_Loss_Based_on_Effective_Number_of_Samples_CVPR_2019_paper.pdf)
1112
"""
1213

13-
def __init__(self, beta: float = None, data_extractor: _ChEBIDataExtractor = None):
14+
def __init__(
15+
self,
16+
beta: float = None,
17+
data_extractor: XYBaseDataModule = None,
18+
):
1419
self.beta = beta
20+
if isinstance(data_extractor, LabeledUnlabeledMixed):
21+
data_extractor = data_extractor.labeled
1522
self.data_extractor = data_extractor
23+
1624
super().__init__()
1725

1826
def set_pos_weight(self, input):
@@ -31,12 +39,12 @@ def set_pos_weight(self, input):
3139
open(
3240
os.path.join(
3341
self.data_extractor.raw_dir,
34-
self.data_extractor.raw_file_names_dict[set],
42+
raw_file_name,
3543
),
3644
"rb",
3745
)
3846
)
39-
for set in ["train", "validation", "test"]
47+
for raw_file_name in self.data_extractor.raw_file_names
4048
]
4149
)
4250
value_counts = []

chebai/loss/semantic.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77
from typing import Literal
88

99
from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor, ChEBIOver100
10+
from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
11+
from chebai.loss.bce_weighted import BCEWeighted
1012

1113

1214
class ImplicationLoss(torch.nn.Module):
1315
def __init__(
1416
self,
15-
data_extractor: _ChEBIDataExtractor,
17+
data_extractor: _ChEBIDataExtractor | LabeledUnlabeledMixed,
1618
base_loss: torch.nn.Module = None,
1719
tnorm: Literal["product", "lukasiewicz", "xu19"] = "product",
1820
impl_loss_weight=0.1, # weight of implication loss in relation to base_loss
@@ -21,7 +23,13 @@ def __init__(
2123
multiply_by_softmax=False,
2224
):
2325
super().__init__()
26+
# automatically choose labeled subset for implication filter in case of mixed dataset
27+
if isinstance(data_extractor, LabeledUnlabeledMixed):
28+
data_extractor = data_extractor.labeled
2429
self.data_extractor = data_extractor
30+
# propagate data_extractor to base loss
31+
if isinstance(base_loss, BCEWeighted):
32+
base_loss.data_extractor = self.data_extractor
2533
self.base_loss = base_loss
2634
self.implication_cache_file = f"implications_{self.data_extractor.name}.cache"
2735
self.label_names = _load_label_names(
@@ -106,7 +114,7 @@ class DisjointLoss(ImplicationLoss):
106114
def __init__(
107115
self,
108116
path_to_disjointness,
109-
data_extractor: _ChEBIDataExtractor,
117+
data_extractor: _ChEBIDataExtractor | LabeledUnlabeledMixed,
110118
base_loss: torch.nn.Module = None,
111119
disjoint_loss_weight=100,
112120
**kwargs,

chebai/models/base.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,9 @@ class ChebaiBaseNet(LightningModule):
2828
2929
Attributes:
3030
NAME (str): The name of the model.
31-
LOSS (torch.nn.Module): The loss function used by the model.
3231
"""
3332

3433
NAME = None
35-
LOSS = torch.nn.BCEWithLogitsLoss
3634

3735
def __init__(
3836
self,

chebai/preprocessing/datasets/base.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,10 @@ def setup_processed(self):
312312
def processed_file_names(self):
313313
raise NotImplementedError
314314

315+
@property
316+
def raw_file_names(self):
317+
raise NotImplementedError
318+
315319
@property
316320
def processed_file_names_dict(self) -> dict:
317321
raise NotImplementedError

configs/loss/semantic_loss.yml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
class_path: chebai.loss.semantic.DisjointLoss
22
init_args:
33
path_to_disjointness: data/disjoint.csv
4-
data_extractor: &extractor ../data/chebi100.yml
54
base_loss:
65
class_path: chebai.loss.bce_weighted.BCEWeighted
76
init_args:
87
beta: 0.99
9-
data_extractor: *extractor
108
tnorm: product
119
impl_loss_weight: 0.01

0 commit comments

Comments
 (0)