Skip to content

Commit e7ab63e

Browse files
Update generative-proof-of-concept-CPU-preprocessing-in-memory.py
Naming consistency, garbage collection.
1 parent cfd192d commit e7ab63e

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

generative-proof-of-concept-CPU-preprocessing-in-memory.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1315,7 +1315,7 @@ def test_text(test_prompt: str, max_new_tokens: int, sample_number: int, result_
13151315

13161316
# Create the Dataset Generator:
13171317
class SampleExpansionGenerator:
1318-
def __init__(self, raw_text_samples, tokenizer, sample_expansion_batch_size=100):
1318+
def __init__(self, raw_text_samples, tokenizer, sample_expansion_batch_size=5):
13191319
self.raw_text_samples = raw_text_samples
13201320
self.tokenizer = tokenizer
13211321
self.sample_expansion_batch_size = sample_expansion_batch_size
@@ -1324,9 +1324,11 @@ def __init__(self, raw_text_samples, tokenizer, sample_expansion_batch_size=100)
13241324
self.current_index = 0
13251325

13261326
def _expand_next_batch(self):
1327+
13271328
# Determine the next meta-batch
13281329
start_idx = self.current_index
13291330
end_idx = min(start_idx + self.sample_expansion_batch_size, len(self.raw_text_samples))
1331+
collect()
13301332
if start_idx >= end_idx:
13311333
raise StopIteration("No more raw samples to process.")
13321334

@@ -1360,7 +1362,7 @@ def __next__(self):
13601362

13611363

13621364
# Create the tf.data.Dataset
1363-
def create_dataset(raw_text_sample, tokenizer, sample_expansion_batch_size=10) -> tf.data.Dataset:
1365+
def create_dataset(raw_text_samples, tokenizer, sample_expansion_batch_size=10) -> tf.data.Dataset:
13641366
generator = SampleExpansionGenerator(raw_text_samples, tokenizer, sample_expansion_batch_size)
13651367

13661368
dataset = tf.data.Dataset.from_generator(
@@ -1372,7 +1374,7 @@ def create_dataset(raw_text_sample, tokenizer, sample_expansion_batch_size=10) -
13721374
)
13731375
return dataset
13741376

1375-
phase_i_b_dataset = create_dataset(raw_text_sample=phase_i_b_samples, tokenizer=tokenizer, sample_expansion_batch_size=10)
1377+
phase_i_b_dataset = create_dataset(raw_text_samples=phase_i_b_samples, tokenizer=tokenizer, sample_expansion_batch_size=10)
13761378

13771379
# To Do: Set .fit() params <------<<<
13781380

0 commit comments

Comments
 (0)