Skip to content

Commit d8db189

Browse files
committed
Add more masked LM sampling styles
Namely, sampling the n-gram length from uniform and normal distributions.
1 parent 7f50532 commit d8db189

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

megatron/data/dataset_utils.py

Lines changed: 12 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,8 @@
4545
class SamplingStyle(Enum):
4646
POISSON = 'poisson'
4747
GEOMETRIC = 'geometric'
48+
UNIFORM = 'uniform'
49+
NORMAL = 'normal'
4850

4951

5052
def analyze_data_prefix(data_prefix):
@@ -254,6 +256,8 @@ def create_masked_lm_predictions(tokens,
254256
pvals /= pvals.sum(keepdims=True)
255257
if favor_longer_ngram:
256258
pvals = pvals[::-1]
259+
elif sampling_style is SamplingStyle.NORMAL:
260+
normal_mean = (max_ngrams + 1) / 2
257261

258262
ngram_indexes = []
259263
for idx in range(len(cand_indexes)):
@@ -287,6 +291,14 @@ def create_masked_lm_predictions(tokens,
287291
# the max_ngrams. Using p=0.2 default from the SpanBERT paper
288292
# https://arxiv.org/pdf/1907.10529.pdf (Sec 3.1)
289293
n = min(np_rng.geometric(0.2), max_ngrams)
294+
elif sampling_style is SamplingStyle.UNIFORM:
295+
n = np_rng.choice(ngrams[:len(cand_index_set)])
296+
elif sampling_style is SamplingStyle.NORMAL:
297+
n = round(np.clip(
298+
np_rng.normal(loc=normal_mean),
299+
1,
300+
len(cand_index_set),
301+
))
290302
else:
291303
raise ValueError('unknown sampling style')
292304

0 commit comments

Comments
 (0)