Skip to content

Commit 142d7dc

Browse files
committed
fix masked sum
1 parent edb802e commit 142d7dc

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

trinity/algorithm/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,13 @@ def aggregate_loss(values, mask, loss_agg_mode="token-mean", normalizer=None):
5151

5252
def masked_sum(values, mask, axis=None):
5353
"""Compute mean of tensor with a masked values."""
54-
return (values * mask).sum(axis=axis)
54+
valid_values = torch.where(mask.bool(), values, 0.0)
55+
return (valid_values * mask).sum(axis=axis)
5556

5657

5758
def masked_mean(values, mask, axis=None):
5859
"""Compute mean of tensor with a masked values."""
59-
return (values * mask).sum(axis=axis) / (mask.sum(axis=axis) + 1e-8)
60+
return masked_sum(values, mask, axis=axis) / (mask.sum(axis=axis) + 1e-8)
6061

6162

6263
def masked_var(values, mask, unbiased=True):

0 commit comments

Comments
 (0)