@@ -430,6 +430,7 @@ def from_config(cls, config):
430430
431431# iRoPE helper functions
432432
433+ @tf .keras .utils .register_keras_serializable ()
433434def split_alternate (x ):
434435 shape = tf .shape (x )
435436 x = tf .reshape (x , [shape [0 ], shape [1 ], shape [2 ] // 2 , 2 ])
@@ -438,13 +439,15 @@ def split_alternate(x):
438439 return x
439440
440441
442+ @tf .keras .utils .register_keras_serializable ()
441443def rotate_half (x ):
442444 x = split_alternate (x )
443445 d = tf .shape (x )[- 1 ]
444446 rotated_x = tf .concat ([- x [..., d // 2 :], x [..., :d // 2 ]], axis = - 1 )
445447 return tf .reshape (rotated_x , tf .shape (x ))
446448
447449
450+ @tf .keras .utils .register_keras_serializable ()
448451def apply_rotary_pos_emb (x , sin , cos ):
449452 cos = tf .reshape (cos , [tf .shape (cos )[0 ], tf .shape (cos )[1 ], - 1 ])
450453 sin = tf .reshape (sin , [tf .shape (sin )[0 ], tf .shape (sin )[1 ], - 1 ])
@@ -533,6 +536,7 @@ def from_config(cls, config):
533536
534537# Custom metric: Perplexity:
535538
539+ @tf .keras .utils .register_keras_serializable ()
536540class Perplexity (tf .keras .metrics .Metric ):
537541 """
538542 Computes perplexity, defined as e^(categorical crossentropy).
0 commit comments