Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 50 additions & 19 deletions blanc/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,30 +175,61 @@ def mask_tokens_evenly(tokens, gap, min_token_lengths, mask_token, gap_mask=1):
all_answers (List[Dict[int, str]]): a list of "answer" dicts, where each answer dict maps
token indices corresponding to masked tokens back to their original token.
"""
gap = min(gap, len(tokens))
token_count = len(tokens)
gap = min(gap, token_count)
if gap < 1 or token_count == 0:
return [], []

# precompute next_tokens list
next_tokens = [''] * token_count
if token_count > 1:
next_tokens[:-1] = tokens[1:]
# precompute large_enough list (avoids recomputing for each modulus)
large_enough_flags = [
is_token_large_enough(tokens[idx], next_tokens[idx], min_token_lengths)
for idx in range(token_count)
]

masked_inputs = []
all_answers = []

# Masking logic
for modulus in range(gap):
masked_input = []
masked_input = tokens.copy()
answers = {}
for idx, token in enumerate(tokens):
next_token = '' if idx + 1 == len(tokens) else tokens[idx + 1]
large_enough = is_token_large_enough(token, next_token, min_token_lengths)

idx_off = idx % gap
if gap == 1:
can_mask = True
elif modulus + gap_mask >= gap:
can_mask = idx_off >= modulus or idx_off < (modulus + gap_mask)%gap
else:
can_mask = idx_off >= modulus and idx_off < modulus + gap_mask
if can_mask and large_enough:
masked_input.append(mask_token)
answers[idx] = token
else:
masked_input.append(token)

if len(answers) > 0:
# Efficiently compute masking indices
if gap == 1:
mask_indices = [idx for idx, large in enumerate(large_enough_flags) if large]
else:
# Precompute modulus + gap_mask for loop
mod_plus_mask = modulus + gap_mask
if mod_plus_mask >= gap:
cutoff = mod_plus_mask % gap
mask_indices = [
idx for idx in range(token_count)
if large_enough_flags[idx] and
(idx % gap >= modulus or idx % gap < cutoff)
]
else:
mask_indices = [
idx for idx in range(modulus, mod_plus_mask)
if idx < token_count and large_enough_flags[idx]
]
# Advance to cover all gap cycles
base = modulus + gap
while base < token_count:
for j in range(gap_mask):
idx = base + j
if idx < token_count and large_enough_flags[idx]:
mask_indices.append(idx)
base += gap

for idx in mask_indices:
masked_input[idx] = mask_token
answers[idx] = tokens[idx]

if answers:
masked_inputs.append(masked_input)
all_answers.append(answers)

Expand Down