"""
Utility package with LLM components.
"""


def prepare_data(data: list[str], tokenizer, max_seq_length: int = 1024, prompt_length: int = 1):
    """
    Prepares tokenized input sequences and corresponding labels for training the Cerebros
    [not so] large language model.

    This function takes raw text data, tokenizes it, and applies a sliding-window approach to
    generate input-label pairs for next-token prediction. Each sample may contain the special
    token `</prompt>`, which separates the prompt from the completion. If this token is not
    present, the sample is treated as a non-instruct example and the default prompt length
    (`prompt_length`, 1 token) is used instead.

    For each token after the prompt (up to the first padding token), the function creates an
    input sequence consisting of all tokens up to (but not including) that token, and sets the
    label to a one-hot encoding of the target token. A final sample is added whose label is the
    pad token, marking the end of the sequence.

    Parameters:
    -----------
    data : list of str
        Input text samples to be processed.
    tokenizer : transformers tokenizer
        Tokenizer used to convert text samples into token IDs. Its vocabulary must contain
        the `</prompt>` token and a pad token.
    max_seq_length : int, optional (default = 1024)
        Maximum sequence length for input tensors. Longer sequences are truncated and shorter
        ones are padded to this length.
    prompt_length : int, optional (default = 1)
        Deprecated, rarely changed (kept for R&D use, to be removed): the number of tokens fed
        to the model before it is expected to start predicting the next token. Only used when
        `</prompt>` is absent from a sample.

    Returns:
    --------
    tuple:
        - all_input_ids (list[list[int]]): Token IDs for each input sequence, shaped
          [num_samples, max_seq_length].
        - all_labels (list[list[int]]): One-hot encoded labels for next-token prediction,
          shaped [num_samples, vocab_size].
        - vocab_size (int): Size of the tokenizer's vocabulary, used for the label dimension.

    Notes:
    ------
    - Special tokens like `</prompt>` are handled manually; no special tokens are inserted
      automatically during tokenization.
    - Padding uses the tokenizer's pad token ID up to `max_seq_length`.
    - `tokenizer`, `max_seq_length`, and `prompt_length` are taken from the function's
      parameters; no globals are required.
    """
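
    # Illustrative expansion of one sample (toy token names, for orientation only):
    #   sample_tokens = [p1, p2, </prompt>, t1, t2, PAD, PAD, ...]
    # yields the (input, label) pairs:
    #   [p1, p2, </prompt>]          -> one-hot(t1)
    #   [p1, p2, </prompt>, t1]      -> one-hot(t2)
    #   [p1, p2, </prompt>, t1, t2]  -> one-hot(PAD)   (termination sample)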

    all_input_ids = []
    all_labels = []

    pad_token_id = tokenizer.pad_token_id

    # Tokenize all data at once for efficiency
    tokenized_data = tokenizer(
        data,
        max_length=max_seq_length,
        padding='max_length',
        truncation=True,
        add_special_tokens=False  # We'll handle special tokens manually
    )
    vocab_size = len(tokenizer)

    # Get the token ID for </prompt>
    end_prompt_token_id = tokenizer.encode("</prompt>", add_special_tokens=False)[0]

    # Process each sample
    for sample_tokens in tokenized_data['input_ids']:
        # Find the index of the </prompt> token
        try:
            end_prompt_index = sample_tokens.index(end_prompt_token_id)
        except ValueError:
            # If </prompt> is not found, treat the sample as a non-instruct sample:
            # give the model `prompt_length` tokens as a fair starting place to predict
            # the next word, which also reduces the number of expanded samples.
            end_prompt_index = prompt_length - 1

        # Find the first pad token after </prompt>
        first_pad_index = None
        for i in range(end_prompt_index + 1, len(sample_tokens)):
            if sample_tokens[i] == pad_token_id:
                first_pad_index = i
                break

        # If no pad token is found, use the end of the sequence
        if first_pad_index is None:
            first_pad_index = len(sample_tokens)

        # Apply the sliding window from after </prompt> to the first pad token.
        # Start at end_prompt_index + 1 (first token to predict)
        # and end at first_pad_index - 1 (last token to predict).
        for i in range(end_prompt_index + 1, first_pad_index):
            # Input: from the start up to (but not including) token i
            input_ids = sample_tokens[:i]

            # Pad or truncate to max_seq_length
            if len(input_ids) > max_seq_length:
                input_ids = input_ids[:max_seq_length]
            else:
                input_ids = input_ids + [pad_token_id] * (max_seq_length - len(input_ids))

            # Label: one-hot encoding of the token at position i
            next_token = sample_tokens[i]
            label = [0] * vocab_size
            label[next_token] = 1

            all_input_ids.append(input_ids)
            all_labels.append(label)

        # Add a final sample with the pad token as the label to indicate termination
        if first_pad_index < len(sample_tokens):  # Only if there is actually a pad token
            input_ids = sample_tokens[:first_pad_index]

            # Pad or truncate to max_seq_length
            if len(input_ids) > max_seq_length:
                input_ids = input_ids[:max_seq_length]
            else:
                input_ids = input_ids + [pad_token_id] * (max_seq_length - len(input_ids))

            # Label: one-hot encoding of the pad token
            label = [0] * vocab_size
            label[pad_token_id] = 1

            all_input_ids.append(input_ids)
            all_labels.append(label)

    return all_input_ids, all_labels, vocab_size
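

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only). Assumes a Hugging Face tokenizer;
# the base checkpoint name, the added pad token, and the sample strings below
# are placeholders, not project defaults. `</prompt>` must be added to the
# tokenizer's vocabulary before calling prepare_data().
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")          # hypothetical base tokenizer
    tok.add_special_tokens({"pad_token": "[PAD]"})       # GPT-2 has no pad token by default
    tok.add_tokens(["</prompt>"])                        # prompt/completion separator

    samples = [
        "What is 2 + 2?</prompt>4",                      # instruct-style sample
        "The quick brown fox jumps over the lazy dog.",  # non-instruct sample
    ]

    input_ids, labels, vocab = prepare_data(samples, tok, max_seq_length=32)
    # One (input, one-hot label) pair is produced per predicted token.
    print(len(input_ids), len(labels), vocab)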