File tree Expand file tree Collapse file tree 1 file changed +12
-0
lines changed Expand file tree Collapse file tree 1 file changed +12
-0
lines changed Original file line number Diff line number Diff line change 4545class SamplingStyle (Enum ):
4646 POISSON = 'poisson'
4747 GEOMETRIC = 'geometric'
48+ UNIFORM = 'uniform'
49+ NORMAL = 'normal'
4850
4951
5052def analyze_data_prefix (data_prefix ):
@@ -254,6 +256,8 @@ def create_masked_lm_predictions(tokens,
254256 pvals /= pvals .sum (keepdims = True )
255257 if favor_longer_ngram :
256258 pvals = pvals [::- 1 ]
259+ elif sampling_style is SamplingStyle .NORMAL :
260+ normal_mean = (max_ngrams + 1 ) / 2
257261
258262 ngram_indexes = []
259263 for idx in range (len (cand_indexes )):
@@ -287,6 +291,14 @@ def create_masked_lm_predictions(tokens,
287291 # the max_ngrams. Using p=0.2 default from the SpanBERT paper
288292 # https://arxiv.org/pdf/1907.10529.pdf (Sec 3.1)
289293 n = min (np_rng .geometric (0.2 ), max_ngrams )
294+ elif sampling_style is SamplingStyle .UNIFORM :
295+ n = np_rng .choice (ngrams [:len (cand_index_set )])
296+ elif sampling_style is SamplingStyle .NORMAL :
297+ n = round (np .clip (
298+ np_rng .normal (loc = normal_mean ),
299+ 1 ,
300+ len (cand_index_set ),
301+ ))
290302 else :
291303 raise ValueError ('unknown sampling style' )
292304
You can’t perform that action at this time.
0 commit comments