@@ -123,18 +123,15 @@ def __call__(
123123
124124 # Token dropout: zero masked tokens + rescale based on observed mask ratio
125125 if self .token_dropout :
126+ # x.masked_fill_((tokens == self.mask_idx).unsqueeze(-1), 0.0)
126127 mask_positions = mx .equal (tokens , self .mask_idx )
127- x = mx .where (mask_positions [:, :, None ], mx .zeros_like (x ), x )
128-
128+ x = mx .where (mask_positions [:, :, None ], 0.0 , x )
129+
130+ # x: B x T x C
129131 mask_ratio_train = 0.15 * 0.8
130- src_lengths = mx .sum (~ padding_mask , axis = - 1 , keepdims = True )
131- mask_ratio_observed = (
132- mx .sum (mask_positions , axis = - 1 , keepdims = True ) / src_lengths
133- )
134- scale_factor = (1 - mask_ratio_train ) / mx .maximum (
135- 1 - mask_ratio_observed , 1e-8
136- )
137- x = x * scale_factor [:, None , :]
132+ src_lengths = mx .sum (~ padding_mask , axis = - 1 ) # Shape: (B,)
133+ mask_ratio_observed = mx .sum (mask_positions , axis = - 1 ).astype (x .dtype ) / src_lengths # Shape: (B,)
134+ x = x * (1 - mask_ratio_train ) / (1 - mask_ratio_observed )[:, None , None ]
138135
139136 # Zero out padding positions
140137 if padding_mask .any ():
0 commit comments