Commit e087dc7

sfluegel committed: rename softmax operation, make it not inplace
1 parent: 347ad3e

File tree: 1 file changed (+4, −4 lines)

chebai/loss/semantic.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -18,7 +18,7 @@ def __init__(
         impl_loss_weight=0.1,  # weight of implication loss in relation to base_loss
         pos_scalar=1,
         pos_epsilon=0.01,
-        loss_attention=False,
+        multiply_by_softmax=False,
     ):
         super().__init__()
         self.data_extractor = data_extractor
@@ -37,7 +37,7 @@ def __init__(
         self.impl_weight = impl_loss_weight
         self.pos_scalar = pos_scalar
         self.eps = pos_epsilon
-        self.loss_attention = loss_attention
+        self.multiply_by_softmax = multiply_by_softmax
 
     def forward(self, input, target, **kwargs):
         nnl = kwargs.pop("non_null_labels", None)
@@ -84,8 +84,8 @@ def _calculate_implication_loss(self, l, r):
         else:
             raise NotImplementedError(f"Unknown tnorm {self.tnorm}")
 
-        if self.loss_attention:
-            individual_loss *= individual_loss.softmax(dim=-1)
+        if self.multiply_by_softmax:
+            individual_loss = individual_loss * individual_loss.softmax(dim=-1)
         return torch.mean(
             torch.sum(individual_loss, dim=-1),
             dim=0,
```
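
For context, a minimal runnable sketch of what the renamed flag does and of the switch to an out-of-place multiplication. The standalone function `reduce_implication_loss` and the random input tensor are illustrative assumptions, not part of the repository; only the two lines inside the `if` block and the final reduction mirror the diff above.

```python
import torch

# Minimal sketch (assumed helper, not the repository's loss class): reduce a
# tensor of per-implication losses the way the diff above does.
def reduce_implication_loss(
    individual_loss: torch.Tensor, multiply_by_softmax: bool = False
) -> torch.Tensor:
    if multiply_by_softmax:
        # Weight each implication's loss by its softmax over the last dimension.
        # The commit replaces the former in-place `*=` with an out-of-place
        # multiplication, which avoids potential autograd errors from modifying
        # in place a tensor that may still be needed for the backward pass.
        individual_loss = individual_loss * individual_loss.softmax(dim=-1)
    # Sum over implications, then average over the batch.
    return torch.mean(torch.sum(individual_loss, dim=-1), dim=0)

# Illustrative usage with dummy values: batch of 8, 100 implications each.
losses = torch.rand(8, 100, requires_grad=True)
print(reduce_implication_loss(losses, multiply_by_softmax=True))
```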
