-import io
 import random
-import tarfile
 
 import numpy as np
-import requests
 from absl import logging
-from datasets import load_dataset
 
 from keras.src import ops
 from keras.src.layers import Dense
@@ -20,136 +16,34 @@ def get_dataloader(tokenizer, seqlen, dataset, nsamples=128, seed=0):
     Prepares and chunks the calibration dataloader, repeating short datasets.
     """
     all_tokens = []
-    rng = np.random.default_rng(seed=42)
 
-    # Unify all input types into a single list of tokens
     if isinstance(dataset, str):
-        logging.info(f"Loading '{dataset}' dataset from Hub...")
-        if dataset == "wikitext2":
-            d_name, d_config = "wikitext", "wikitext-2-raw-v1"
-        elif dataset == "ptb":
-            url = "https://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz"
-            try:
-                # Download the archive into memory
-                response = requests.get(url)
-                response.raise_for_status()
-
-                # Extract only the test file from the in-memory archive
-                with tarfile.open(
-                    fileobj=io.BytesIO(response.content), mode="r:gz"
-                ) as tar:
-                    train_path = "./simple-examples/data/ptb.train.txt"
-                    test_bytes = tar.extractfile(train_path).read()
-
-                # Decode the bytes and join into a single string
-                test_lines = test_bytes.decode("utf-8").strip().split("\n")
-                full_text = "\n\n".join(test_lines)
-                all_tokens = tokenizer.tokenize(full_text)
-                logging.info(
-                    "✅ Successfully processed PTB training data for"
-                    "calibration."
-                )
+        raise TypeError(
+            "The `dataset` argument must be an iterable (e.g., a list or "
+            "generator) of strings or pre-tokenized tensors. Loading "
+            "datasets by name is no longer supported."
+        )
 
-                # Perform sampling and chunking directly inside this block
-                all_tokens = np.array(all_tokens, dtype=np.int32)
-                required_tokens = nsamples * seqlen
-                if len(all_tokens) < required_tokens:
-                    logging.info(
-                        f"Warning: PTB dataset is too short ({len(all_tokens)} "
-                        "tokens). Repeating data."
-                    )
-                    repeats = -(-required_tokens // len(all_tokens))
-                    all_tokens = np.tile(all_tokens, repeats)
-
-                calibration_samples = []
-                for _ in range(nsamples):
-                    start_index = rng.integers(
-                        low=0, high=len(all_tokens) - seqlen
-                    )
-                    end_index = start_index + seqlen
-                    sample = all_tokens[start_index:end_index]
-                    calibration_samples.append(ops.reshape(sample, (1, seqlen)))
-
-                final_array = ops.stack(calibration_samples, axis=0)
-
-                # Return the correctly shaped array, isolating the logic
-                return ops.convert_to_numpy(final_array)
-
-            except Exception as e:
-                logging.info(f"Failed to download or process PTB data: {e!r}")
-                raise e
-        elif dataset == "c4":
-            logging.info(
-                " -> Using memory-efficient streaming strategy for C4."
-            )
-            streaming_dataset = load_dataset(
-                "allenai/c4", "en", split="train", streaming=True
-            )
-            dataset_head = streaming_dataset.take(nsamples * 5)
-
-            samples = []
-            docs_for_sampling = list(dataset_head)
-
-            for _ in range(nsamples):
-                while True:
-                    doc = random.choice(docs_for_sampling)
-                    try:
-                        # Call the tokenizer layer directly (the KerasNLP way)
-                        # and squeeze the output to a 1D array.
-                        tokenized_doc = np.squeeze(tokenizer(doc["text"]))
-                        if len(tokenized_doc) >= seqlen:
-                            break
-                    except Exception:
-                        docs_for_sampling.remove(doc)
-                        if not docs_for_sampling:
-                            raise ValueError(
-                                "Could not find enough valid documents"
-                                "in the C4 sample."
-                            )
-                        continue
-
-                j = rng.integers(low=0, high=len(tokenized_doc) - seqlen)
-                sample_slice = tokenized_doc[j : j + seqlen]
-                samples.append(np.reshape(sample_slice, (1, seqlen)))
-
-            return np.array(samples, dtype=np.int32)
-        else:
-            logging.info(
-                f"Attempting to load '{dataset}' directly with its "
-                "default configuration."
-            )
-            d_name = dataset
-            d_config = None  # Use the default configuration for the dataset
+    logging.info("\n==> Using pre-made dataset/generator...")
+    dataset_list = list(dataset)
 
-        # Default to "text" for wikitext2 and other datasets
-        text_column = "text"
+    if not dataset_list:
+        raise ValueError("Provided dataset is empty.")
 
-        raw_dataset = load_dataset(d_name, d_config, split="train")
-        text_list = [d[text_column] for d in raw_dataset]
-        full_text = "\n\n".join(text_list)
+    if isinstance(dataset_list[0], str):
+        logging.info(" (Dataset contains strings, tokenizing now...)")
+        full_text = "\n\n".join(dataset_list)
         all_tokens = tokenizer.tokenize(full_text)
-
     else:
-        logging.info("Using pre-made dataset/generator")
-        dataset_list = list(dataset)
-
-        if not dataset_list:
-            raise ValueError("Provided dataset is empty.")
-
-        if isinstance(dataset_list[0], str):
-            logging.info(" (Dataset contains strings, tokenizing now...)")
-            full_text = "\n\n".join(dataset_list)
-            all_tokens = tokenizer.tokenize(full_text)
-        else:
-            logging.info(" (Dataset is pre-tokenized, concatenating...)")
-            concatenated_tokens = ops.concatenate(
-                [ops.reshape(s, [-1]) for s in dataset_list], axis=0
-            )
-            all_tokens = ops.convert_to_numpy(concatenated_tokens)
+        logging.info(" (Dataset is pre-tokenized, concatenating...)")
+        concatenated_tokens = ops.concatenate(
+            [ops.reshape(s, [-1]) for s in dataset_list], axis=0
+        )
+        all_tokens = ops.convert_to_numpy(concatenated_tokens)
 
     all_tokens = np.array(all_tokens, dtype=np.int32)
 
-    # --- Step 2: Repeat data if it's too short ---
+    # Repeat data if it's too short
     required_tokens = nsamples * seqlen
     if len(all_tokens) < required_tokens:
         logging.info(
@@ -159,10 +53,12 @@ def get_dataloader(tokenizer, seqlen, dataset, nsamples=128, seed=0):
         repeats = -(-required_tokens // len(all_tokens))  # Ceiling division
         all_tokens = np.tile(all_tokens, repeats)
 
+    # Chunk the token list into samples
+
     calibration_samples = []
     for _ in range(nsamples):
         # Generate a random starting index
-        start_index = rng.integers(low=0, high=len(all_tokens) - seqlen)
+        start_index = random.randint(0, len(all_tokens) - seqlen - 1)
         end_index = start_index + seqlen
         sample = all_tokens[start_index:end_index]
         calibration_samples.append(ops.reshape(sample, (1, seqlen)))
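For reference, here is a minimal usage sketch of the new calling convention. It is hypothetical rather than taken from the diff: it assumes `get_dataloader` is imported from this module, and `ToyTokenizer` merely stands in for any tokenizer exposing a `tokenize()` method that returns integer token IDs (for example, a KerasHub tokenizer). Callers now supply the calibration texts (or pre-tokenized tensors) themselves; passing a dataset name such as `"wikitext2"` raises a `TypeError`.

```python
# Hypothetical sketch, not part of the diff. Assumes `get_dataloader` (above)
# is in scope; `ToyTokenizer` stands in for any tokenizer with a `tokenize()`
# method returning integer token IDs.
import numpy as np


class ToyTokenizer:
    """Toy whitespace tokenizer mapping each distinct word to an integer ID."""

    def tokenize(self, text):
        vocab = {}
        return [vocab.setdefault(word, len(vocab)) for word in text.split()]


# The caller now provides the calibration texts (or pre-tokenized tensors).
calibration_texts = [
    "The quick brown fox jumps over the lazy dog.",
    "Calibration only needs a small sample of representative text.",
]

samples = get_dataloader(
    tokenizer=ToyTokenizer(),
    seqlen=32,
    dataset=calibration_texts,  # a name like "wikitext2" now raises TypeError
    nsamples=4,
)
print(np.shape(samples))  # the loop above builds one (1, seqlen) chunk per sample
```

Keeping dataset loading out of this helper also removes the `io`/`tarfile`/`requests`/`datasets` dependencies and leaves the choice of calibration corpus entirely to the caller.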