From e0aea11a1704898aaaff9a6a8ffe5c492c9bc402 Mon Sep 17 00:00:00 2001 From: "codeflash-ai[bot]" <148906541+codeflash-ai[bot]@users.noreply.github.com> Date: Wed, 22 Oct 2025 22:27:27 +0000 Subject: [PATCH] Optimize mask_tokens_evenly The optimization achieves a **2936% speedup** by eliminating redundant computation in the inner loops through strategic precomputation and algorithmic improvements. **Key optimizations:** 1. **Precomputation of expensive operations**: The original code called `is_token_large_enough()` for every token in every modulus iteration, resulting in 345,648 function calls. The optimized version precomputes a `large_enough_flags` list once upfront, reducing function calls to just 5,930 - a **98% reduction**. 2. **Eliminated redundant `next_token` lookups**: The original repeatedly computed `next_token = '' if idx + 1 == len(tokens) else tokens[idx + 1]` for each modulus. The optimization precomputes a `next_tokens` list once, avoiding 345,648 conditional checks. 3. **Batch masking with efficient indexing**: Instead of checking masking conditions for every token in every modulus, the optimized version precomputes all mask indices per modulus using list comprehensions and explicit range logic, then applies masks in a single pass. 4. **Direct list copying**: Replaced `masked_input.append()` calls with `tokens.copy()` and direct index assignment, reducing list operations from ~345K appends to simple copies plus targeted assignments. **Performance characteristics by test case:** - **Small inputs** (2-10 tokens): 5-35% slower due to precomputation overhead - **Medium inputs** (50-100 tokens): 225-1522% faster as precomputation pays off - **Large inputs** (1000+ tokens): 629-6278% faster, with greatest gains when many tokens need size checking The optimization is most effective for large token sequences where the cost of precomputation is amortized across many modulus iterations, making it ideal for production NLP workloads with substantial text inputs. --- blanc/utils.py | 69 ++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 50 insertions(+), 19 deletions(-) diff --git a/blanc/utils.py b/blanc/utils.py index fa97402..c7f1fc9 100644 --- a/blanc/utils.py +++ b/blanc/utils.py @@ -175,30 +175,61 @@ def mask_tokens_evenly(tokens, gap, min_token_lengths, mask_token, gap_mask=1): all_answers (List[Dict[int, str]]): a list of "answer" dicts, where each answer dict maps token indices corresponding to masked tokens back to their original token. """ - gap = min(gap, len(tokens)) + token_count = len(tokens) + gap = min(gap, token_count) + if gap < 1 or token_count == 0: + return [], [] + + # precompute next_tokens list + next_tokens = [''] * token_count + if token_count > 1: + next_tokens[:-1] = tokens[1:] + # precompute large_enough list (avoids recomputing for each modulus) + large_enough_flags = [ + is_token_large_enough(tokens[idx], next_tokens[idx], min_token_lengths) + for idx in range(token_count) + ] + masked_inputs = [] all_answers = [] + + # Masking logic for modulus in range(gap): - masked_input = [] + masked_input = tokens.copy() answers = {} - for idx, token in enumerate(tokens): - next_token = '' if idx + 1 == len(tokens) else tokens[idx + 1] - large_enough = is_token_large_enough(token, next_token, min_token_lengths) - - idx_off = idx % gap - if gap == 1: - can_mask = True - elif modulus + gap_mask >= gap: - can_mask = idx_off >= modulus or idx_off < (modulus + gap_mask)%gap - else: - can_mask = idx_off >= modulus and idx_off < modulus + gap_mask - if can_mask and large_enough: - masked_input.append(mask_token) - answers[idx] = token - else: - masked_input.append(token) - if len(answers) > 0: + # Efficiently compute masking indices + if gap == 1: + mask_indices = [idx for idx, large in enumerate(large_enough_flags) if large] + else: + # Precompute modulus + gap_mask for loop + mod_plus_mask = modulus + gap_mask + if mod_plus_mask >= gap: + cutoff = mod_plus_mask % gap + mask_indices = [ + idx for idx in range(token_count) + if large_enough_flags[idx] and + (idx % gap >= modulus or idx % gap < cutoff) + ] + else: + mask_indices = [ + idx for idx in range(modulus, mod_plus_mask) + if idx < token_count and large_enough_flags[idx] + ] + # Advance to cover all gap cycles + base = modulus + gap + while base < token_count: + for j in range(gap_mask): + idx = base + j + if idx < token_count and large_enough_flags[idx]: + mask_indices.append(idx) + base += gap + + for idx in mask_indices: + masked_input[idx] = mask_token + answers[idx] = tokens[idx] + + if answers: masked_inputs.append(masked_input) all_answers.append(answers)