Commit 4203bf2

Implement prepare_data function for LLM training
Move prepare_data to a package (1 of ...). This utility function prepares tokenized input sequences and labels for training a language model. It handles special tokens, applies a sliding-window approach, and returns input-label pairs along with the vocabulary size.
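As a quick illustration of the sliding-window expansion described above, here is a toy sketch (not part of this commit; the token IDs and pad ID are made up) of how one tokenized sample becomes several input/next-token pairs:

# Toy illustration only: made-up token IDs, with 42 standing in for the pad token.
sample_tokens = [7, 3, 9, 42, 42]
pad_token_id = 42
end_prompt_index = 0                      # prompt is just the first token

first_pad_index = sample_tokens.index(pad_token_id)
pairs = []
for i in range(end_prompt_index + 1, first_pad_index):
    pairs.append((sample_tokens[:i], sample_tokens[i]))        # (inputs so far, next token)
pairs.append((sample_tokens[:first_pad_index], pad_token_id))  # termination sample

# pairs == [([7], 3), ([7, 3], 9), ([7, 3, 9], 42)]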
1 parent 6a46f77 commit 4203bf2

File tree

1 file changed: +132 -0 lines changed


cerebrosllmutils/llm_utils.py

Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
"""
Utility package with LLM components.
"""


def prepare_data(data: list[str], tokenizer, max_seq_length: int = 1024, prompt_length: int = 1):
    """
    Prepares tokenized input sequences and corresponding labels for training the Cerebros
    [not so] large language model.

    This function takes raw text data, tokenizes it, and applies a sliding-window approach to
    generate input-label pairs for next-token prediction tasks. Each sample may contain the
    special token `</prompt>`, which separates the prompt from the completion. If this token
    is not present, the sample is treated as a non-instruct example and the default prompt
    length (1 token) is used.

    For each token after the prompt (up to the first padding token), the function creates an
    input sequence consisting of all tokens up to (but not including) that token, and sets the
    label to a one-hot encoded vector of the target token. A final sample is added whose label
    is the pad token, indicating the end of the sequence.

    Parameters:
    -----------
    data : list of str
        List of input text samples to be processed.
    tokenizer : transformers tokenizer
        The tokenizer used to convert text to token IDs; must define a pad token.
    max_seq_length : int, optional (default = 1024)
        Maximum sequence length for input tensors. Sequences longer than this are truncated,
        and shorter ones are padded.
    prompt_length : int, optional (default = 1)
        Rarely changed, deprecated (for R and D use), to be removed: the number of tokens fed
        to the model at training before the model is expected to start predicting the next
        token. Used only when a sample contains no `</prompt>` token.

    Returns:
    --------
    tuple:
        - all_input_ids (list[list[int]]): Token IDs for each input sequence, shaped
          [num_samples, max_seq_length].
        - all_labels (list[list[int]]): One-hot encoded labels for next-token prediction,
          shaped [num_samples, vocab_size].
        - vocab_size (int): Size of the tokenizer's vocabulary, used for label dimensions.

    Notes:
    ------
    - Special tokens like `</prompt>` are handled manually; no automatic special token insertion.
    - Padding uses the tokenizer's pad token ID, up to `max_seq_length`.
    """

    all_input_ids = []
    all_labels = []

    pad_token_id = tokenizer.pad_token_id

    # Tokenize all data at once for efficiency
    tokenized_data = tokenizer(
        data,
        max_length=max_seq_length,
        padding='max_length',
        truncation=True,
        add_special_tokens=False  # Special tokens are handled manually
    )
    vocab_size = len(tokenizer)

    # Get the token ID for </prompt>
    end_prompt_token_id = tokenizer.encode("</prompt>", add_special_tokens=False)[0]

    # Process each sample
    for sample_tokens in tokenized_data['input_ids']:
        # Find the index of the </prompt> token
        try:
            end_prompt_index = sample_tokens.index(end_prompt_token_id)
        except ValueError:
            # If </prompt> is not found, treat the sample as a non-instruct sample:
            # 1. give the model a fair starting place to predict the next word,
            # 2. reduce the number of expanded samples.
            end_prompt_index = prompt_length - 1

        # Find the first pad token after </prompt>
        first_pad_index = None
        for i in range(end_prompt_index + 1, len(sample_tokens)):
            if sample_tokens[i] == pad_token_id:
                first_pad_index = i
                break

        # If no pad token is found, use the end of the sequence
        if first_pad_index is None:
            first_pad_index = len(sample_tokens)

        # Apply a sliding window from after </prompt> to the first pad token.
        # Start at end_prompt_index + 1 (first token to predict);
        # end at first_pad_index - 1 (last token to predict).
        for i in range(end_prompt_index + 1, first_pad_index):
            # Input: all tokens from the start up to (but not including) token i
            input_ids = sample_tokens[:i]

            # Pad or truncate to max_seq_length
            if len(input_ids) > max_seq_length:
                input_ids = input_ids[:max_seq_length]
            else:
                input_ids = input_ids + [pad_token_id] * (max_seq_length - len(input_ids))

            # Label: one-hot encoding of the token at position i
            next_token = sample_tokens[i]
            label = [0] * vocab_size
            label[next_token] = 1

            all_input_ids.append(input_ids)
            all_labels.append(label)

        # Add a final sample with the pad token as the label to indicate termination
        if first_pad_index < len(sample_tokens):  # Only if there is actually a pad token
            input_ids = sample_tokens[:first_pad_index]

            # Pad or truncate to max_seq_length
            if len(input_ids) > max_seq_length:
                input_ids = input_ids[:max_seq_length]
            else:
                input_ids = input_ids + [pad_token_id] * (max_seq_length - len(input_ids))

            # Label: one-hot encoding of the pad token
            label = [0] * vocab_size
            label[pad_token_id] = 1

            all_input_ids.append(input_ids)
            all_labels.append(label)

    return all_input_ids, all_labels, vocab_size
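
A minimal usage sketch, assuming a Hugging Face transformers tokenizer; the gpt2 checkpoint, the <pad> token, and the </prompt> special-token registration below are illustrative assumptions, not part of this commit:

# Illustrative usage only; checkpoint name and special-token setup are assumptions.
import numpy as np
from transformers import AutoTokenizer

from cerebrosllmutils.llm_utils import prepare_data

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any tokenizer with (or given) a pad token
tokenizer.add_special_tokens({"pad_token": "<pad>",
                              "additional_special_tokens": ["</prompt>"]})

texts = [
    "What is Cerebros?</prompt>A neural architecture search framework.",
    "Plain, non-instruct text without a prompt marker.",
]

all_input_ids, all_labels, vocab_size = prepare_data(texts, tokenizer, max_seq_length=64)

x = np.array(all_input_ids)  # shape: [num_samples, 64]
y = np.array(all_labels)     # shape: [num_samples, vocab_size]
print(x.shape, y.shape, vocab_size)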

0 commit comments
