blanc/utils.py: 36 changes (20 additions & 16 deletions)
@@ -222,25 +222,29 @@ def mask_tokens_randomly(tokens, min_token_lengths, mask_token, p_mask):
"""
n_mask = max(int(len(tokens) * p_mask), 1)

token_positions = []
for idx, token in enumerate(tokens):
next_token = '' if idx + 1 == len(tokens) else tokens[idx + 1]
if is_token_large_enough(token, next_token, min_token_lengths):
token_positions.append(idx)
# Precompute next_token list once for efficiency
next_tokens = tokens[1:] + ['']
token_positions = [
idx for idx, (token, next_token) in enumerate(zip(tokens, next_tokens))
if is_token_large_enough(token, next_token, min_token_lengths)
]

random.shuffle(token_positions)

all_inputs, all_answers = [], []
while len(token_positions) > 0:
positions_to_mask = token_positions[:n_mask]
token_positions = token_positions[n_mask:]
# Avoid repeated lookup of positions_to_mask
all_inputs = []
all_answers = []
position_count = len(token_positions)
if position_count == 0:
return all_inputs, all_answers

inputs, answers = [], {}
for idx, token in enumerate(tokens):
if idx in positions_to_mask:
inputs.append(mask_token)
answers[idx] = token
else:
inputs.append(token)
# Make a copy to avoid resizing list in every iteration (efficient slicing)
for start in range(0, position_count, n_mask):
positions_to_mask = set(token_positions[start:start + n_mask])

# Use list comprehension for inputs; dict comprehension for answers
inputs = [mask_token if idx in positions_to_mask else token for idx, token in enumerate(tokens)]
answers = {idx: tokens[idx] for idx in positions_to_mask}

all_inputs.append(inputs)
all_answers.append(answers)
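For reference, here is a minimal standalone sketch of the chunked masking scheme the new code follows: collect the eligible positions, shuffle them, then emit one masked copy of the token sequence per chunk of n_mask positions, using a set for O(1) membership tests. The function name mask_tokens_randomly_sketch, the simple length-based eligibility check, and the '[MASK]' placeholder are assumptions for illustration only; the real function in blanc/utils.py relies on is_token_large_enough and the tokenizer's mask token.

```python
import random

MASK = '[MASK]'  # assumption: the real mask token comes from the tokenizer in use


def mask_tokens_randomly_sketch(tokens, min_len=4, p_mask=0.15):
    """Chunked random masking: each eligible position is masked exactly once
    across the returned copies, with roughly len(tokens) * p_mask masks per copy."""
    n_mask = max(int(len(tokens) * p_mask), 1)

    # Simplified eligibility check (the real code uses is_token_large_enough)
    token_positions = [i for i, tok in enumerate(tokens) if len(tok) >= min_len]
    random.shuffle(token_positions)

    all_inputs, all_answers = [], []
    for start in range(0, len(token_positions), n_mask):
        chunk = set(token_positions[start:start + n_mask])  # set -> O(1) membership tests
        inputs = [MASK if i in chunk else tok for i, tok in enumerate(tokens)]
        answers = {i: tokens[i] for i in chunk}
        all_inputs.append(inputs)
        all_answers.append(answers)
    return all_inputs, all_answers


if __name__ == '__main__':
    sentence = 'the quick brown fox jumps over the lazy dog'.split()
    for masked, answers in zip(*mask_tokens_randomly_sketch(sentence, p_mask=0.3)):
        print(masked, answers)
```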