Skip to content

Commit 1f64801

Browse files
Update train_a_generative_llm.py
FIx loss
1 parent b531405 commit 1f64801

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

train_a_generative_llm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -445,7 +445,7 @@
445445
p_lateral_connection_decay=zero_95_exp_decay,
446446
num_lateral_connection_tries_per_unit=num_lateral_connection_tries_per_unit,
447447
learning_rate=learning_rate,
448-
loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
448+
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
449449
metrics=[tf.keras.metrics.SparseCategoricalAccuracy(),
450450
sparse_perplexity_metric,
451451
# tf.keras.metrics.Accuracy()

0 commit comments

Comments
 (0)