Commit 347ad3e

add "loss_attention"

Author: sfluegel
Parent: 1f6a5ff

File tree: 1 file changed, +4 -0 lines

chebai/loss/semantic.py

Lines changed: 4 additions & 0 deletions
@@ -18,6 +18,7 @@ def __init__(
         impl_loss_weight=0.1,  # weight of implication loss in relation to base_loss
         pos_scalar=1,
         pos_epsilon=0.01,
+        loss_attention=False,
     ):
         super().__init__()
         self.data_extractor = data_extractor
@@ -36,6 +37,7 @@ def __init__(
         self.impl_weight = impl_loss_weight
         self.pos_scalar = pos_scalar
         self.eps = pos_epsilon
+        self.loss_attention = loss_attention

     def forward(self, input, target, **kwargs):
         nnl = kwargs.pop("non_null_labels", None)
@@ -82,6 +84,8 @@ def _calculate_implication_loss(self, l, r):
         else:
             raise NotImplementedError(f"Unknown tnorm {self.tnorm}")

+        if self.loss_attention:
+            individual_loss *= individual_loss.softmax(dim=-1)
         return torch.mean(
             torch.sum(individual_loss, dim=-1),
             dim=0,
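
For context on what the new flag does: when `loss_attention` is enabled, each per-constraint loss is multiplied by its own softmax weight before the reduction. Since the softmax weights sum to 1 along the constraint axis, the inner sum becomes a convex combination of the individual losses that emphasizes the largest violations. A minimal standalone sketch of that reweighting (the tensor shape and values below are illustrative assumptions, not taken from the repository):

```python
import torch

# Hypothetical per-constraint implication losses for a batch of 2 samples
# and 4 implication constraints (values are made up for illustration).
individual_loss = torch.tensor(
    [[0.1, 0.2, 0.3, 2.0],
     [0.5, 0.5, 0.5, 0.5]]
)

# loss_attention=False: every constraint contributes equally.
plain = torch.mean(torch.sum(individual_loss, dim=-1), dim=0)

# loss_attention=True: scale each loss by its softmax weight along the
# constraint axis; the inner sum then acts as a soft, differentiable
# approximation of the per-sample maximum loss.
attended = torch.mean(
    torch.sum(individual_loss * individual_loss.softmax(dim=-1), dim=-1),
    dim=0,
)

print(plain)     # tensor(2.3000)
print(attended)  # ~0.95: dominated by the 2.0 outlier in the first row
```

Note that the softmax is taken per sample along the last dimension: in the second row, where all losses are equal, the weights are uniform and each loss is simply scaled by 1/4. The reweighting therefore changes the relative contribution of constraints within a sample, not across the batch.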
