diff --git a/blanc/utils.py b/blanc/utils.py index fa97402..c7f1fc9 100644 --- a/blanc/utils.py +++ b/blanc/utils.py @@ -175,30 +175,61 @@ def mask_tokens_evenly(tokens, gap, min_token_lengths, mask_token, gap_mask=1): all_answers (List[Dict[int, str]]): a list of "answer" dicts, where each answer dict maps token indices corresponding to masked tokens back to their original token. """ - gap = min(gap, len(tokens)) + token_count = len(tokens) + gap = min(gap, token_count) + if gap < 1 or token_count == 0: + return [], [] + + # precompute next_tokens list + next_tokens = [''] * token_count + if token_count > 1: + next_tokens[:-1] = tokens[1:] + # precompute large_enough list (avoids recomputing for each modulus) + large_enough_flags = [ + is_token_large_enough(tokens[idx], next_tokens[idx], min_token_lengths) + for idx in range(token_count) + ] + masked_inputs = [] all_answers = [] + + # Masking logic for modulus in range(gap): - masked_input = [] + masked_input = tokens.copy() answers = {} - for idx, token in enumerate(tokens): - next_token = '' if idx + 1 == len(tokens) else tokens[idx + 1] - large_enough = is_token_large_enough(token, next_token, min_token_lengths) - - idx_off = idx % gap - if gap == 1: - can_mask = True - elif modulus + gap_mask >= gap: - can_mask = idx_off >= modulus or idx_off < (modulus + gap_mask)%gap - else: - can_mask = idx_off >= modulus and idx_off < modulus + gap_mask - if can_mask and large_enough: - masked_input.append(mask_token) - answers[idx] = token - else: - masked_input.append(token) - if len(answers) > 0: + # Efficiently compute masking indices + if gap == 1: + mask_indices = [idx for idx, large in enumerate(large_enough_flags) if large] + else: + # Precompute modulus + gap_mask for loop + mod_plus_mask = modulus + gap_mask + if mod_plus_mask >= gap: + cutoff = mod_plus_mask % gap + mask_indices = [ + idx for idx in range(token_count) + if large_enough_flags[idx] and + (idx % gap >= modulus or idx % gap < cutoff) + ] + else: + mask_indices = [ + idx for idx in range(modulus, mod_plus_mask) + if idx < token_count and large_enough_flags[idx] + ] + # Advance to cover all gap cycles + base = modulus + gap + while base < token_count: + for j in range(gap_mask): + idx = base + j + if idx < token_count and large_enough_flags[idx]: + mask_indices.append(idx) + base += gap + + for idx in mask_indices: + masked_input[idx] = mask_token + answers[idx] = tokens[idx] + + if answers: masked_inputs.append(masked_input) all_answers.append(answers)