Skip to content

Commit c93633e

Browse files
Update train_a_generative_llm.py
Fix type float32 <> int32, fix output sig for the tf.data.Dataset.
1 parent 1052d32 commit c93633e

File tree

1 file changed

+5
-4
lines changed

1 file changed

+5
-4
lines changed

train_a_generative_llm.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -255,15 +255,15 @@
255255

256256
x_train_tf = tf.constant(X_train, tf.int32)
257257
print(x_train_tf)
258-
y_train_tf = tf.constant(y_train, tf.float32)
258+
y_train_tf = tf.constant(y_train, tf.int32)
259259
print(y_train_tf)
260260

261261
x_train_packaged = [x_train_tf]
262262
y_train_packaged = [y_train_tf]
263263

264264
x_test_tf = tf.constant(X_test, tf.int32)
265-
y_test_tf = tf.constant(y_test, tf.float32)
266-
265+
y_test_tf = tf.constant(y_test, tf.int32)
266+
267267
x_test_packaged = [x_test_tf]
268268
y_test_packaged = [y_test_tf]
269269

@@ -819,7 +819,8 @@ def create_dataset(raw_text_samples, tokenizer, sample_expansion_batch_size=50,
819819
# )
820820
output_signature=(
821821
(tf.TensorSpec(shape=(generator_0.max_seq_length,), dtype=tf.int32),), # A tuple containing ONE TensorSpec
822-
tf.TensorSpec(shape=(generator_0.vocabulary_size,), dtype=tf.float32) # A single TensorSpec
822+
tf.TensorSpec(shape=(), dtype=tf.int32)
823+
# tf.TensorSpec(shape=(generator_0.vocabulary_size,), dtype=tf.float32) # A single TensorSpec
823824
)
824825
)
825826

0 commit comments

Comments
 (0)