Commit 7e4d85d

Update llm_utils.py
Moved iRoPE embedding to cerebrosllmutils.
1 parent: 6955ac7 · commit: 7e4d85d

1 file changed: +152 additions, −11 deletions

cerebrosllmutils/llm_utils.py

Lines changed: 152 additions & 11 deletions
@@ -7,6 +7,7 @@
 
 
 from typing import List, Tuple, Any
+import tensorflow as tf
 
 
 
@@ -16,18 +17,18 @@ def prepare_data(
         max_seq_length: int = 1024,
         prompt_length: int = 1) -> Tuple[List[List[int]], List[List[int]], int]:
     """
-    Prepares tokenized input sequences and corresponding labels for training the Cerebros 
+    Prepares tokenized input sequences and corresponding labels for training the Cerebros
     [not so] large language model.
 
-    This function takes raw text data, tokenizes it, and applies a sliding window approach to 
-    generate input-label pairs for next-token prediction tasks. It assumes that each sample may 
-    contain a special token `</prompt>` which separates the prompt from the completion. If this 
-    token is not present, the sample is treated as a non-instruct example and a default prompt 
+    This function takes raw text data, tokenizes it, and applies a sliding window approach to
+    generate input-label pairs for next-token prediction tasks. It assumes that each sample may
+    contain a special token `</prompt>` which separates the prompt from the completion. If this
+    token is not present, the sample is treated as a non-instruct example and a default prompt
     length (1 token) is used.
 
-    For each token after the prompt (up to the first padding token), it creates an input sequence 
-    consisting of all tokens up to (but not including) that token, and sets the label as a one-hot 
-    encoded vector of the target token. A final sample is added where the label is the pad token, 
+    For each token after the prompt (up to the first padding token), it creates an input sequence
+    consisting of all tokens up to (but not including) that token, and sets the label as a one-hot
+    encoded vector of the target token. A final sample is added where the label is the pad token,
     indicating the end of the sequence.
 
     Parameters:
@@ -45,7 +46,7 @@ def prepare_data(
     Returns:
     --------
     tuple:
-        - all_input_ids (2d list of int): Tuple[List[List[int]]] Token IDs for each input sequence, shaped 
+        - all_input_ids (2d list of int): Tuple[List[List[int]]] Token IDs for each input sequence, shaped
          [num_samples, max_seq_length].
         - all_labels (2d list of int): Tuple[List[List[int]]] One-hot encoded labels for next-token prediction,
          shaped [num_samples, vocab_size].
@@ -55,7 +56,7 @@ def prepare_data(
     ------
     - Special tokens like `</prompt>` are handled manually; no automatic special token insertion.
     - Padding is done using the tokenizer's pad token ID to MAX_SEQ_LENGTH.
-    - The function assumes global variables `tokenizer`, `MAX_SEQ_LENGTH`, `PROMPT_LENGTH`, and 
+    - The function assumes global variables `tokenizer`, `MAX_SEQ_LENGTH`, `PROMPT_LENGTH`, and
      `vocab_size` are defined in the scope where this function is called.
     """
 
@@ -85,7 +86,7 @@ def prepare_data(
         except ValueError:
             # If </prompt> not found, treat sample as a non-instruct sample
             end_prompt_index = (
-                prompt_length - 1) # int(np.ceil(len(sample_tokens) * (1/3))) # 0 ## 1. Give it a fair starting place to predict the next word 2. reduce the number of expanded samples 
+                prompt_length - 1) # int(np.ceil(len(sample_tokens) * (1/3))) # 0 ## 1. Give it a fair starting place to predict the next word 2. reduce the number of expanded samples
 
             # Find first pad token after </prompt>
             first_pad_index = None
@@ -137,3 +138,143 @@ def prepare_data(
             all_labels.append(label)
 
     return all_input_ids, all_labels, vocab_size
+
+
+# --- Base Rotary Positional Embedding
+@tf.keras.utils.register_keras_serializable()
+class RotaryEmbedding(tf.keras.layers.Layer):
+    def __init__(self, dim, max_seq_len=1024, temperature=10000.0, **kwargs):
+        super().__init__(**kwargs)
+        self.dim = dim
+        # Ensure dim is even right at initialization
+        if self.dim % 2 != 0:
+            raise ValueError(f"Embedding dimension `dim` ({self.dim}) must be even for RotaryEmbedding.")
+        self.max_seq_len = max_seq_len
+        self.temperature = temperature
+        # *** No calculation or storage of inv_freq here or in build ***
+
+    def build(self, input_shape):
+        # Build should primarily be for creating trainable weights, which we don't have.
+        # Call super().build() for Keras compatibility.
+        super().build(input_shape)
+
+    def call(self, x):  # Removed seq_len argument, calculate from x
+        shape = tf.shape(x)
+        batch_size = shape[0]
+        actual_seq_len = shape[1]
+
+        # *** Calculate inv_freq inside call ***
+        inv_freq_base = tf.range(0, self.dim, 2, dtype=tf.float32)
+        inv_freq = 1.0 / (self.temperature ** (inv_freq_base / self.dim))
+        # Ensure inv_freq has the correct shape [dim/2]
+        inv_freq = tf.cast(inv_freq, dtype=x.dtype)  # Match dtype early
+
+        # Use actual_seq_len for calculations
+        position = tf.range(actual_seq_len, dtype=x.dtype)  # Match dtype
+
+        # Calculate sinusoid input using einsum or broadcasting
+        # Einsum approach: Ensure correct dimensions [seq_len, dim/2]
+        sinusoid_inp = tf.einsum("i,j->ij", position, inv_freq)
+
+        # Calculate sin and cos based on the actual sequence length
+        sin = tf.sin(sinusoid_inp)
+        cos = tf.cos(sinusoid_inp)
+
+        # Repeat sin/cos for interleaving: [a, b] -> [a, a, b, b]
+        # Result needs shape [actual_seq_len, dim]
+        sin = tf.repeat(sin, 2, axis=-1)
+        cos = tf.repeat(cos, 2, axis=-1)
+
+        # Expand dims for batch and tile
+        # Output shape needs to be [batch_size, actual_seq_len, dim]
+        # Add batch dimension: [1, actual_seq_len, dim]
+        sin = tf.expand_dims(sin, axis=0)
+        cos = tf.expand_dims(cos, axis=0)
+
+        # Tile to match the batch size: [batch_size, actual_seq_len, dim]
+        sin = tf.tile(sin, [batch_size, 1, 1])
+        cos = tf.tile(cos, [batch_size, 1, 1])
+
+        # Casting to x.dtype was already done for inv_freq, sin/cos will inherit
+        # sin = tf.cast(sin, x.dtype)  # Already done via calculation chain
+        # cos = tf.cast(cos, x.dtype)  # Already done via calculation chain
+
+        # Return sin and cos needed by InterleavedRoPE
+        return sin, cos
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({
+            "dim": self.dim,
+            "max_seq_len": self.max_seq_len,
+            "temperature": self.temperature,
+        })
+        return config
+
+    @classmethod
+    def from_config(cls, config):
+        return cls(**config)
+
+
+# iRoPE helper functions
+
+@tf.keras.utils.register_keras_serializable()
+def split_alternate(x):
+    shape = tf.shape(x)
+    x = tf.reshape(x, [shape[0], shape[1], shape[2] // 2, 2])
+    x = tf.transpose(x, [0, 1, 3, 2])
+    x = tf.reshape(x, [shape[0], shape[1], -1])
+    return x
+
+
+@tf.keras.utils.register_keras_serializable()
+def rotate_half(x):
+    x = split_alternate(x)
+    d = tf.shape(x)[-1]
+    rotated_x = tf.concat([-x[..., d // 2:], x[..., :d // 2]], axis=-1)
+    return tf.reshape(rotated_x, tf.shape(x))
+
+
+@tf.keras.utils.register_keras_serializable()
+def apply_rotary_pos_emb(x, sin, cos):
+    cos = tf.reshape(cos, [tf.shape(cos)[0], tf.shape(cos)[1], -1])
+    sin = tf.reshape(sin, [tf.shape(sin)[0], tf.shape(sin)[1], -1])
+    x_rotated = x * cos + rotate_half(x) * sin
+    return x_rotated
+
+
+# Interleaved Rotary Positional Embedding (iRoPE)
+@tf.keras.utils.register_keras_serializable()
+class InterleavedRoPE(tf.keras.layers.Layer):
+    def __init__(self, dim, max_seq_len=1024, **kwargs):
+        super().__init__(**kwargs)
+        if dim % 2 != 0:
+            raise ValueError(f"Embedding dimension `dim` ({dim}) must be even for InterleavedRoPE.")
+        self.dim = dim
+        self.max_seq_len = max_seq_len
+        # Instantiate the RotaryEmbedding layer
+        # Ensure the name is consistent if needed for saving/loading
+        self.rotary_emb = RotaryEmbedding(dim, max_seq_len, name="rotary_embedding")
+
+    def call(self, x):
+        # Get sin and cos from the RotaryEmbedding layer's call method
+        # *** Pass only 'x'. RotaryEmbedding calculates seq_len internally. ***
+        sin, cos = self.rotary_emb(x)
+
+        # Apply the positional embeddings
+        x_embedded = apply_rotary_pos_emb(x, sin, cos)
+        return x_embedded
+
+    def get_config(self):
+        config = super().get_config()
+        config.update({
+            "dim": self.dim,
+            "max_seq_len": self.max_seq_len,
+        })
+        # Keras handles nested layer serialization automatically
+        return config
+
+    @classmethod
+    def from_config(cls, config):
+        # Keras handles nested layer restoration automatically
+        return cls(**config)
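
For reference, a minimal wiring sketch (not part of this commit) of how the new InterleavedRoPE layer could be attached to token embeddings in a Keras model. The vocabulary size, embedding width, and import path below are illustrative assumptions.

import tensorflow as tf
from cerebrosllmutils.llm_utils import InterleavedRoPE  # assumed import path

VOCAB_SIZE = 32000      # assumption for illustration
MAX_SEQ_LENGTH = 1024   # matches the layer's default max_seq_len
EMBED_DIM = 256         # must be even, per the ValueError check in __init__

token_ids = tf.keras.Input(shape=(MAX_SEQ_LENGTH,), dtype=tf.int32)
h = tf.keras.layers.Embedding(VOCAB_SIZE, EMBED_DIM)(token_ids)
# Rotate the token embeddings by position before any attention blocks.
h = InterleavedRoPE(dim=EMBED_DIM, max_seq_len=MAX_SEQ_LENGTH)(h)
model = tf.keras.Model(token_ids, h)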

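A quick eager-mode sanity check, again only a sketch with assumed shapes: the layer preserves the input's shape, and position 0 passes through unchanged because its rotation angle is zero (sin = 0, cos = 1).

import tensorflow as tf
from cerebrosllmutils.llm_utils import InterleavedRoPE  # assumed import path

batch, seq_len, dim = 2, 8, 16                 # dim must be even
x = tf.random.normal([batch, seq_len, dim])

rope = InterleavedRoPE(dim=dim, max_seq_len=seq_len)
y = rope(x)

assert y.shape == x.shape                          # RoPE rotates features, it does not project them
tf.debugging.assert_near(y[:, 0, :], x[:, 0, :])   # zero angle at position 0 -> identity
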