diff --git a/blanc/utils.py b/blanc/utils.py
index fa97402..d261b5e 100644
--- a/blanc/utils.py
+++ b/blanc/utils.py
@@ -222,25 +222,31 @@ def mask_tokens_randomly(tokens, min_token_lengths, mask_token, p_mask):
     """
     n_mask = max(int(len(tokens) * p_mask), 1)
 
-    token_positions = []
-    for idx, token in enumerate(tokens):
-        next_token = '' if idx + 1 == len(tokens) else tokens[idx + 1]
-        if is_token_large_enough(token, next_token, min_token_lengths):
-            token_positions.append(idx)
+    # Pair each token with its successor ('' for the last token) so the
+    # eligibility filter runs as a single comprehension pass.
+    next_tokens = tokens[1:] + ['']
+    token_positions = [
+        idx for idx, (token, next_token) in enumerate(zip(tokens, next_tokens))
+        if is_token_large_enough(token, next_token, min_token_lengths)
+    ]
+
     random.shuffle(token_positions)
 
-    all_inputs, all_answers = [], []
-    while len(token_positions) > 0:
-        positions_to_mask = token_positions[:n_mask]
-        token_positions = token_positions[n_mask:]
+    all_inputs = []
+    all_answers = []
+    position_count = len(token_positions)
+    # No token was long enough to mask: nothing to build.
+    if position_count == 0:
+        return all_inputs, all_answers
 
-        inputs, answers = [], {}
-        for idx, token in enumerate(tokens):
-            if idx in positions_to_mask:
-                inputs.append(mask_token)
-                answers[idx] = token
-            else:
-                inputs.append(token)
+    # Walk the shuffled positions in chunks of n_mask. Sort each chunk so the
+    # answers dict keeps the ascending-index insertion order the original
+    # enumerate-based loop produced (dict order is observable in Python 3.7+).
+    for start in range(0, position_count, n_mask):
+        chunk = sorted(token_positions[start:start + n_mask])
+        positions_to_mask = set(chunk)  # O(1) membership for the comprehension
+
+        inputs = [mask_token if idx in positions_to_mask else token for idx, token in enumerate(tokens)]
+        answers = {idx: tokens[idx] for idx in chunk}
 
         all_inputs.append(inputs)
         all_answers.append(answers)