Skip to content

Commit ee66ad1

Browse files
committed
Fixed token dropout
1 parent d7c5f46 commit ee66ad1

File tree

1 file changed

+7
-10
lines changed

1 file changed

+7
-10
lines changed

esm/esm/model.py

Lines changed: 7 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -123,18 +123,15 @@ def __call__(
123123

124124
# Token dropout: zero masked tokens + rescale based on observed mask ratio
125125
if self.token_dropout:
126+
# x.masked_fill_((tokens == self.mask_idx).unsqueeze(-1), 0.0)
126127
mask_positions = mx.equal(tokens, self.mask_idx)
127-
x = mx.where(mask_positions[:, :, None], mx.zeros_like(x), x)
128-
128+
x = mx.where(mask_positions[:, :, None], 0.0, x)
129+
130+
# x: B x T x C
129131
mask_ratio_train = 0.15 * 0.8
130-
src_lengths = mx.sum(~padding_mask, axis=-1, keepdims=True)
131-
mask_ratio_observed = (
132-
mx.sum(mask_positions, axis=-1, keepdims=True) / src_lengths
133-
)
134-
scale_factor = (1 - mask_ratio_train) / mx.maximum(
135-
1 - mask_ratio_observed, 1e-8
136-
)
137-
x = x * scale_factor[:, None, :]
132+
src_lengths = mx.sum(~padding_mask, axis=-1) # Shape: (B,)
133+
mask_ratio_observed = mx.sum(mask_positions, axis=-1).astype(x.dtype) / src_lengths # Shape: (B,)
134+
x = x * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]
138135

139136
# Zero out padding positions
140137
if padding_mask.any():

0 commit comments

Comments (0)