@@ -7,20 +7,29 @@
 from typing import Literal
 
 from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor, ChEBIOver100
+from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed
+from chebai.loss.bce_weighted import BCEWeighted
 
 
 class ImplicationLoss(torch.nn.Module):
     def __init__(
         self,
-        data_extractor: _ChEBIDataExtractor,
+        data_extractor: _ChEBIDataExtractor | LabeledUnlabeledMixed,
         base_loss: torch.nn.Module = None,
         tnorm: Literal["product", "lukasiewicz", "xu19"] = "product",
         impl_loss_weight=0.1,  # weight of implication loss in relation to base_loss
         pos_scalar=1,
         pos_epsilon=0.01,
+        multiply_by_softmax=False,
     ):
         super().__init__()
+        # automatically choose labeled subset for implication filter in case of mixed dataset
+        if isinstance(data_extractor, LabeledUnlabeledMixed):
+            data_extractor = data_extractor.labeled
         self.data_extractor = data_extractor
+        # propagate data_extractor to base loss
+        if isinstance(base_loss, BCEWeighted):
+            base_loss.data_extractor = self.data_extractor
         self.base_loss = base_loss
         self.implication_cache_file = f"implications_{self.data_extractor.name}.cache"
         self.label_names = _load_label_names(
@@ -36,6 +45,7 @@ def __init__(
         self.impl_weight = impl_loss_weight
         self.pos_scalar = pos_scalar
         self.eps = pos_epsilon
+        self.multiply_by_softmax = multiply_by_softmax
 
     def forward(self, input, target, **kwargs):
         nnl = kwargs.pop("non_null_labels", None)
@@ -70,16 +80,20 @@ def _calculate_implication_loss(self, l, r):
                 math.pow(1 + self.eps, 1 / self.pos_scalar)
                 - math.pow(self.eps, 1 / self.pos_scalar)
             )
-            r = torch.pow(r, self.pos_scalar)
+            one_min_r = torch.pow(1 - r, self.pos_scalar)
+        else:
+            one_min_r = 1 - r
         if self.tnorm == "product":
-            individual_loss = l * (1 - r)
+            individual_loss = l * one_min_r
         elif self.tnorm == "xu19":
-            individual_loss = -torch.log(1 - l * (1 - r))
+            individual_loss = -torch.log(1 - l * one_min_r)
         elif self.tnorm == "lukasiewicz":
-            individual_loss = torch.relu(l - r)
+            individual_loss = torch.relu(l + one_min_r - 1)
         else:
             raise NotImplementedError(f"Unknown tnorm {self.tnorm}")
 
+        if self.multiply_by_softmax:
+            individual_loss = individual_loss * individual_loss.softmax(dim=-1)
         return torch.mean(
             torch.sum(individual_loss, dim=-1),
             dim=0,
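
As a side note on the revised branches: each t-norm variant penalizes a confident antecedent paired with an unconfident consequent of the implication l -> r. A minimal standalone sketch (not taken from the repository), assuming pos_scalar == 1 so that one_min_r is simply 1 - r:

import torch

# l, r: predicted probabilities for the antecedent and consequent of l -> r
l = torch.tensor([0.9, 0.2])
r = torch.tensor([0.1, 0.8])
one_min_r = 1 - r

product = l * one_min_r                      # "product": large when l high and r low
xu19 = -torch.log(1 - l * one_min_r)         # "xu19": log-scaled version of the same term
lukasiewicz = torch.relu(l + one_min_r - 1)  # "lukasiewicz": clipped at zero
# with pos_scalar == 1 the Lukasiewicz form reduces to the old relu(l - r)
assert torch.allclose(lukasiewicz, torch.relu(l - r))
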
@@ -100,7 +114,7 @@ class DisjointLoss(ImplicationLoss):
     def __init__(
         self,
         path_to_disjointness,
-        data_extractor: _ChEBIDataExtractor,
+        data_extractor: _ChEBIDataExtractor | LabeledUnlabeledMixed,
         base_loss: torch.nn.Module = None,
         disjoint_loss_weight=100,
         **kwargs,
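
A hedged usage sketch of the changed constructors; the default-constructibility of ChEBIOver100 and BCEWeighted, the disjointness file name, and the assumption that DisjointLoss forwards **kwargs on to ImplicationLoss are not shown in the diff:

from chebai.loss.bce_weighted import BCEWeighted
from chebai.preprocessing.datasets.chebi import ChEBIOver100

extractor = ChEBIOver100()  # a LabeledUnlabeledMixed instance would now also be accepted
loss = DisjointLoss(
    path_to_disjointness="disjointness.csv",  # hypothetical file
    data_extractor=extractor,
    base_loss=BCEWeighted(),  # its data_extractor is overwritten to match the dataset
    disjoint_loss_weight=100,
    multiply_by_softmax=True,  # new flag, assumed to reach ImplicationLoss via **kwargs
)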