Skip to content

Commit 9303909

Browse files
Update train_a_generative_llm.py
Replace stage 1-a perplexity with its sparse analogue.
1 parent 4d09a07 commit 9303909

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

train_a_generative_llm.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@
2121
from cerebrosllmutils.llm_utils import (
2222
prepare_data,
2323
InterleavedRoPE,
24-
Perplexity,
24+
SparsePerplexity,
2525
GatedMergeLayer,
2626
ChunkedAttentionBlock,
2727
MambaBlock,
@@ -409,7 +409,7 @@
409409
meta_trial_number = 42 # irrelevant unless in distributed training
410410

411411
# Custom metric: Perplexity:
412-
perplexity_metric = Perplexity()
412+
sparse_perplexity_metric = SparsePerplexity
413413

414414
cerebros_automl = SimpleCerebrosRandomSearch(
415415
unit_type=DenseUnit,
@@ -447,7 +447,7 @@
447447
learning_rate=learning_rate,
448448
loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
449449
metrics=[tf.keras.metrics.SparseCategoricalAccuracy(),
450-
perplexity_metric, # Need to fix...
450+
sparse_perplexity_metric,
451451
# tf.keras.metrics.Accuracy()
452452
],
453453
epochs=epochs,

0 commit comments

Comments
 (0)