Skip to content

Commit 983fb5b

Browse files
Update llm_utils.py
Add metadata to @tf.keras.utils.register_keras_serializable() ...
1 parent dcfe66d commit 983fb5b

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

cerebrosllmutils/llm_utils.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def prepare_data(
141141

142142

143143
# --- Base Rotary Positional Embedding
144-
@tf.keras.utils.register_keras_serializable()
144+
@tf.keras.utils.register_keras_serializable(package='cerebrosllmutils', name='RotaryEmbedding')
145145
class RotaryEmbedding(tf.keras.layers.Layer):
146146
def __init__(self, dim, max_seq_len=1024, temperature=10000.0, **kwargs):
147147
super().__init__(**kwargs)
@@ -218,7 +218,7 @@ def from_config(cls, config):
218218

219219
# iRoPE helper functions
220220

221-
@tf.keras.utils.register_keras_serializable()
221+
@tf.keras.utils.register_keras_serializable(package='cerebrosllmutils', name='split_alternate')
222222
def split_alternate(x):
223223
shape = tf.shape(x)
224224
x = tf.reshape(x, [shape[0], shape[1], shape[2] // 2, 2])
@@ -227,15 +227,15 @@ def split_alternate(x):
227227
return x
228228

229229

230-
@tf.keras.utils.register_keras_serializable()
230+
@tf.keras.utils.register_keras_serializable(package='cerebrosllmutils', name='rotate_half')
231231
def rotate_half(x):
232232
x = split_alternate(x)
233233
d = tf.shape(x)[-1]
234234
rotated_x = tf.concat([-x[..., d // 2:], x[..., :d // 2]], axis=-1)
235235
return tf.reshape(rotated_x, tf.shape(x))
236236

237237

238-
@tf.keras.utils.register_keras_serializable()
238+
@tf.keras.utils.register_keras_serializable(package='cerebrosllmutils', name='apply_rotary_pos_emb')
239239
def apply_rotary_pos_emb(x, sin, cos):
240240
cos = tf.reshape(cos, [tf.shape(cos)[0], tf.shape(cos)[1], -1])
241241
sin = tf.reshape(sin, [tf.shape(sin)[0], tf.shape(sin)[1], -1])
@@ -244,7 +244,7 @@ def apply_rotary_pos_emb(x, sin, cos):
244244

245245

246246
# interleaved Rotary Positional Embedding (iRoPE)
247-
@tf.keras.utils.register_keras_serializable()
247+
@tf.keras.utils.register_keras_serializable(package='cerebrosllmutils', name='InterleavedRoPE')
248248
class InterleavedRoPE(tf.keras.layers.Layer):
249249
def __init__(self, dim, max_seq_len=1024, **kwargs):
250250
super().__init__(**kwargs)
@@ -280,7 +280,7 @@ def from_config(cls, config):
280280
return cls(**config)
281281

282282

283-
@tf.keras.utils.register_keras_serializable()
283+
@tf.keras.utils.register_keras_serializable(package='cerebrosllmutils', name='Perplexity')
284284
class Perplexity(tf.keras.metrics.Metric):
285285
"""
286286
Computes perplexity, defined as e^(categorical crossentropy).
@@ -310,7 +310,7 @@ def reset_state(self):
310310
self.total_crossentropy.assign(0.0)
311311
self.count.assign(0.0)
312312

313-
@tf.keras.utils.register_keras_serializable()
313+
@tf.keras.utils.register_keras_serializable(package='cerebrosllmutils', name='CerebrosNotGPTConfig')
314314
class CerebrosNotGPTConfig:
315315
def __init__(self, max_sequence_length=1536, padding_token=None):
316316
self.max_sequence_length = max_sequence_length
@@ -328,7 +328,7 @@ def from_config(cls, config):
328328
return cls(**config) # No model_0 to handle
329329

330330

331-
@tf.keras.utils.register_keras_serializable()
331+
@tf.keras.utils.register_keras_serializable(package='cerebrosllmutils', name='CerebrosNotGPT')
332332
class CerebrosNotGPT(tf.keras.Model):
333333
def __init__(self, config, model_0=None, **kwargs):
334334
super().__init__(**kwargs)

0 commit comments

Comments
 (0)