Commit 921af13

Removed the huggingface dependency
1 parent bd57ff1 commit 921af13

File tree

2 files changed: +21 -126 lines

keras/src/quantizers/gptqutils.py

Lines changed: 21 additions & 125 deletions
@@ -1,11 +1,7 @@
-import io
 import random
-import tarfile
 
 import numpy as np
-import requests
 from absl import logging
-from datasets import load_dataset
 
 from keras.src import ops
 from keras.src.layers import Dense
@@ -20,136 +16,34 @@ def get_dataloader(tokenizer, seqlen, dataset, nsamples=128, seed=0):
     Prepares and chunks the calibration dataloader, repeating short datasets.
     """
     all_tokens = []
-    rng = np.random.default_rng(seed=42)
 
-    # Unify all input types into a single list of tokens
     if isinstance(dataset, str):
-        logging.info(f"Loading '{dataset}' dataset from Hub...")
-        if dataset == "wikitext2":
-            d_name, d_config = "wikitext", "wikitext-2-raw-v1"
-        elif dataset == "ptb":
-            url = "https://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz"
-            try:
-                # Download the archive into memory
-                response = requests.get(url)
-                response.raise_for_status()
-
-                # Extract only the test file from the in-memory archive
-                with tarfile.open(
-                    fileobj=io.BytesIO(response.content), mode="r:gz"
-                ) as tar:
-                    train_path = "./simple-examples/data/ptb.train.txt"
-                    test_bytes = tar.extractfile(train_path).read()
-
-                # Decode the bytes and join into a single string
-                test_lines = test_bytes.decode("utf-8").strip().split("\n")
-                full_text = "\n\n".join(test_lines)
-                all_tokens = tokenizer.tokenize(full_text)
-                logging.info(
-                    "✅ Successfully processed PTB training data for"
-                    "calibration."
-                )
+        raise TypeError(
+            "The `dataset` argument must be an iterable (e.g., a list or "
+            "generator) of strings or pre-tokenized tensors. Loading "
+            "datasets by name is no longer supported."
+        )
 
-                # Perform sampling and chunking directly inside this block
-                all_tokens = np.array(all_tokens, dtype=np.int32)
-                required_tokens = nsamples * seqlen
-                if len(all_tokens) < required_tokens:
-                    logging.info(
-                        f"Warning: PTB dataset is too short ({len(all_tokens)}"
-                        "tokens). Repeating data."
-                    )
-                    repeats = -(-required_tokens // len(all_tokens))
-                    all_tokens = np.tile(all_tokens, repeats)
-
-                calibration_samples = []
-                for _ in range(nsamples):
-                    start_index = rng.integers(
-                        low=0, high=len(all_tokens) - seqlen
-                    )
-                    end_index = start_index + seqlen
-                    sample = all_tokens[start_index:end_index]
-                    calibration_samples.append(ops.reshape(sample, (1, seqlen)))
-
-                final_array = ops.stack(calibration_samples, axis=0)
-
-                # Return the correctly shaped array, isolating the logic
-                return ops.convert_to_numpy(final_array)
-
-            except Exception as e:
-                logging.info(f"Failed to download or process PTB data: {e!r}")
-                raise e
-        elif dataset == "c4":
-            logging.info(
-                " -> Using memory-efficient streaming strategy for C4."
-            )
-            streaming_dataset = load_dataset(
-                "allenai/c4", "en", split="train", streaming=True
-            )
-            dataset_head = streaming_dataset.take(nsamples * 5)
-
-            samples = []
-            docs_for_sampling = list(dataset_head)
-
-            for _ in range(nsamples):
-                while True:
-                    doc = random.choice(docs_for_sampling)
-                    try:
-                        # Call the tokenizer layer directly (the KerasNLP way)
-                        # and squeeze the output to a 1D array.
-                        tokenized_doc = np.squeeze(tokenizer(doc["text"]))
-                        if len(tokenized_doc) >= seqlen:
-                            break
-                    except Exception:
-                        docs_for_sampling.remove(doc)
-                        if not docs_for_sampling:
-                            raise ValueError(
-                                "Could not find enough valid documents"
-                                "in the C4 sample."
-                            )
-                        continue
-
-                j = rng.integers(low=0, high=len(tokenized_doc) - seqlen)
-                sample_slice = tokenized_doc[j : j + seqlen]
-                samples.append(np.reshape(sample_slice, (1, seqlen)))
-
-            return np.array(samples, dtype=np.int32)
-        else:
-            logging.info(
-                f"Attempting to load '{dataset}' directly with its "
-                "default configuration."
-            )
-            d_name = dataset
-            d_config = None  # Use the default configuration for the dataset
+    logging.info("\n==> Using pre-made dataset/generator...")
+    dataset_list = list(dataset)
 
-        # Default to "text" for wikitext2 and other datasets
-        text_column = "text"
+    if not dataset_list:
+        raise ValueError("Provided dataset is empty.")
 
-        raw_dataset = load_dataset(d_name, d_config, split="train")
-        text_list = [d[text_column] for d in raw_dataset]
-        full_text = "\n\n".join(text_list)
+    if isinstance(dataset_list[0], str):
+        logging.info(" (Dataset contains strings, tokenizing now...)")
+        full_text = "\n\n".join(dataset_list)
         all_tokens = tokenizer.tokenize(full_text)
-
     else:
-        logging.info("Using pre-made dataset/generator")
-        dataset_list = list(dataset)
-
-        if not dataset_list:
-            raise ValueError("Provided dataset is empty.")
-
-        if isinstance(dataset_list[0], str):
-            logging.info(" (Dataset contains strings, tokenizing now...)")
-            full_text = "\n\n".join(dataset_list)
-            all_tokens = tokenizer.tokenize(full_text)
-        else:
-            logging.info(" (Dataset is pre-tokenized, concatenating...)")
-            concatenated_tokens = ops.concatenate(
-                [ops.reshape(s, [-1]) for s in dataset_list], axis=0
-            )
-            all_tokens = ops.convert_to_numpy(concatenated_tokens)
+        logging.info(" (Dataset is pre-tokenized, concatenating...)")
+        concatenated_tokens = ops.concatenate(
+            [ops.reshape(s, [-1]) for s in dataset_list], axis=0
+        )
+        all_tokens = ops.convert_to_numpy(concatenated_tokens)
 
     all_tokens = np.array(all_tokens, dtype=np.int32)
 
-    # --- Step 2: Repeat data if it's too short ---
+    # Repeat data if it's too short
     required_tokens = nsamples * seqlen
     if len(all_tokens) < required_tokens:
         logging.info(
@@ -159,10 +53,12 @@ def get_dataloader(tokenizer, seqlen, dataset, nsamples=128, seed=0):
         repeats = -(-required_tokens // len(all_tokens))  # Ceiling division
         all_tokens = np.tile(all_tokens, repeats)
 
+    # Chunk the token list into samples
+
     calibration_samples = []
     for _ in range(nsamples):
         # Generate a random starting index
-        start_index = rng.integers(low=0, high=len(all_tokens) - seqlen)
+        start_index = random.randint(0, len(all_tokens) - seqlen - 1)
         end_index = start_index + seqlen
         sample = all_tokens[start_index:end_index]
         calibration_samples.append(ops.reshape(sample, (1, seqlen)))
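
With this change, callers pass calibration data to get_dataloader directly instead of naming a Hub dataset. The sketch below is illustrative only and not part of the commit: the import path simply mirrors the file modified above, and ToyTokenizer plus the sample texts are invented stand-ins for any tokenizer whose tokenize() method returns a flat sequence of integer token ids, which is what the string branch above expects.

# Illustrative sketch only (not from the commit). `ToyTokenizer` and the
# calibration texts are invented; `get_dataloader` and its signature come
# from the diff above (keras/src/quantizers/gptqutils.py).
import numpy as np

from keras.src.quantizers.gptqutils import get_dataloader


class ToyTokenizer:
    """Minimal stand-in: maps each whitespace-separated word to an int id."""

    def __init__(self):
        self.vocab = {}

    def tokenize(self, text):
        # Assign ids on first sight and return a flat array of token ids,
        # which is what the string branch of get_dataloader expects.
        return np.array(
            [self.vocab.setdefault(w, len(self.vocab)) for w in text.split()],
            dtype=np.int32,
        )


# Strings are passed directly; a dataset name such as "wikitext2" now
# raises the TypeError added in this commit.
calibration_texts = [
    "GPTQ calibration needs representative text from the target domain. " * 40,
    "Each string is tokenized and chunked into fixed-length samples. " * 40,
]
samples = get_dataloader(
    tokenizer=ToyTokenizer(),
    seqlen=64,
    dataset=calibration_texts,
    nsamples=8,
)
# Per the chunking loop above, each of the 8 samples is reshaped to (1, 64).

# Pre-tokenized input also works: 1D integer arrays/tensors are flattened
# and concatenated by the else branch before chunking (tokenizer unused).
pretokenized = [np.arange(200, dtype=np.int32), np.arange(300, dtype=np.int32)]
samples_from_ids = get_dataloader(
    tokenizer=ToyTokenizer(),
    seqlen=64,
    dataset=pretokenized,
    nsamples=4,
)

A generator of strings works as well, since the function materializes the input with list(dataset) before inspecting the first element.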

requirements-common.txt

Lines changed: 0 additions & 1 deletion
@@ -29,4 +29,3 @@ onnxscript<=0.3.1
 openvino
 # for grain_dataset_adapter_test.py
 grain
-datasets

0 commit comments
