Skip to content

Commit aa9281d

Browse files
Update generative-proof-of-concept-CPU-preprocessing-in-memory.py
Try getting the tensor spec correct ...
1 parent 51dc82e commit aa9281d

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

generative-proof-of-concept-CPU-preprocessing-in-memory.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1355,8 +1355,10 @@ def __next__(self):
13551355
self._expand_next_batch()
13561356

13571357
# Pop and return one sample
1358-
input_sample = [self.data.pop(0)] # Nested as per model input spec
1359-
label_sample = [self.labels.pop(0)] # Nested as per model output spec
1358+
# input_sample = [self.data.pop(0)] # Nested as per model input spec
1359+
# label_sample = [self.labels.pop(0)] # Nested as per model output spec
1360+
input_sample = [self.data.pop(0)]
1361+
label_sample = [self.labels.pop(0)]
13601362

13611363
return (input_sample, label_sample)
13621364

@@ -1368,8 +1370,10 @@ def create_dataset(raw_text_samples, tokenizer, sample_expansion_batch_size=10)
13681370
dataset = tf.data.Dataset.from_generator(
13691371
lambda: generator,
13701372
output_signature=(
1371-
tf.TensorSpec(shape=(1, MAX_SEQ_LENGTH), dtype=tf.int32), # Nested input
1372-
tf.TensorSpec(shape=(1, VOCABULARY_SIZE), dtype=tf.float32) # Nested one-hot label
1373+
tf.TensorSpec(shape=(MAX_SEQ_LENGTH,), dtype=tf.float32),
1374+
tf.TensorSpec(shape=(VOCABULARY_SIZE,), dtype=tf.float32)
1375+
# tf.TensorSpec(shape=(1, MAX_SEQ_LENGTH), dtype=tf.int32), # Nested input
1376+
# tf.TensorSpec(shape=(1, VOCABULARY_SIZE), dtype=tf.float32) # Nested one-hot label
13731377
)
13741378
)
13751379
# Set dataset to allow multiple epochs:

0 commit comments

Comments
 (0)