Commit a16d669

Reworked some superficial comments
1 parent 99492e0 commit a16d669

File tree

1 file changed: 11 additions, 15 deletions


keras/src/quantizers/gptqutils.py

Lines changed: 11 additions & 15 deletions
@@ -23,26 +23,26 @@ def get_dataloader(tokenizer, seqlen, dataset, nsamples=128, seed=0):
     all_tokens = []
     rng = np.random.default_rng(seed=42)
 
-    # --- Step 1: Unify all input types into a single list of tokens ---
+    # Unify all input types into a single list of tokens
     if isinstance(dataset, str):
         logging.info(f"Loading '{dataset}' dataset from Hub...")
         if dataset == "wikitext2":
             d_name, d_config = "wikitext", "wikitext-2-raw-v1"
         elif dataset == "ptb":
             url = "https://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz"
             try:
-                # 1. Download the archive into memory
+                # Download the archive into memory
                 response = requests.get(url)
                 response.raise_for_status()
 
-                # 2. Extract only the test file from the in-memory archive
+                # Extract only the test file from the in-memory archive
                 with tarfile.open(
                     fileobj=io.BytesIO(response.content), mode="r:gz"
                 ) as tar:
                     train_path = "./simple-examples/data/ptb.train.txt"
                     test_bytes = tar.extractfile(train_path).read()
 
-                # 3. Decode the bytes and join into a single string
+                # Decode the bytes and join into a single string
                 test_lines = test_bytes.decode("utf-8").strip().split("\n")
                 full_text = "\n\n".join(test_lines)
                 all_tokens = tokenizer.tokenize(full_text)
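
The hunk above downloads the PTB archive and reads one member straight from memory. A minimal standalone sketch of that pattern, using only the URL and member path visible in the diff (variable names here are illustrative, not the actual gptqutils code):

```python
import io
import tarfile

import requests

url = "https://www.fit.vutbr.cz/~imikolov/rnnlm/simple-examples.tgz"
response = requests.get(url)
response.raise_for_status()  # surface HTTP errors early

# Open the gzipped tar directly from the response bytes; no temp file needed.
with tarfile.open(fileobj=io.BytesIO(response.content), mode="r:gz") as tar:
    member = tar.extractfile("./simple-examples/data/ptb.train.txt")
    raw_bytes = member.read()

full_text = "\n\n".join(raw_bytes.decode("utf-8").strip().split("\n"))
```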
@@ -51,7 +51,7 @@ def get_dataloader(tokenizer, seqlen, dataset, nsamples=128, seed=0):
                     "calibration."
                 )
 
-                # 2. Perform sampling and chunking directly inside this block
+                # Perform sampling and chunking directly inside this block
                 all_tokens = np.array(all_tokens, dtype=np.int32)
                 required_tokens = nsamples * seqlen
                 if len(all_tokens) < required_tokens:
@@ -73,7 +73,7 @@ def get_dataloader(tokenizer, seqlen, dataset, nsamples=128, seed=0):
 
                 final_array = ops.stack(calibration_samples, axis=0)
 
-                # 3. Return the correctly shaped array, isolating the logic
+                # Return the correctly shaped array, isolating the logic
                 return ops.convert_to_numpy(final_array)
 
             except Exception as e:
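
For context on the return path touched here, a toy example of how a list of fixed-length chunks becomes the `(nsamples, seqlen)` array via keras.ops (shapes and values below are made up):

```python
import numpy as np
from keras import ops

seqlen = 8
calibration_samples = [np.arange(i, i + seqlen, dtype=np.int32) for i in range(4)]

final_array = ops.stack(calibration_samples, axis=0)  # shape (4, 8)
print(ops.convert_to_numpy(final_array).shape)        # -> (4, 8)
```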
@@ -115,7 +115,6 @@ def get_dataloader(tokenizer, seqlen, dataset, nsamples=128, seed=0):
 
             return np.array(samples, dtype=np.int32)
         else:
-            logging.info(f"Warning: No specific alias found for '{dataset}'.")
             logging.info(
                 f"Attempting to load '{dataset}' directly with its "
                 "default configuration."
@@ -132,7 +131,7 @@ def get_dataloader(tokenizer, seqlen, dataset, nsamples=128, seed=0):
         all_tokens = tokenizer.tokenize(full_text)
 
     else:
-        logging.info("\n==> Using pre-made dataset/generator...")
+        logging.info("Using pre-made dataset/generator")
         dataset_list = list(dataset)
 
         if not dataset_list:
@@ -161,9 +160,6 @@ def get_dataloader(tokenizer, seqlen, dataset, nsamples=128, seed=0):
         repeats = -(-required_tokens // len(all_tokens))  # Ceiling division
         all_tokens = np.tile(all_tokens, repeats)
 
-    # --- Step 3: Chunk the token list into samples ---
-    # utils.set_random_seed(seed)
-
     calibration_samples = []
     for _ in range(nsamples):
         # Generate a random starting index
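
The surrounding code tiles a short token stream until it is long enough and then cuts random windows out of it. A rough, self-contained sketch of that tile-and-chunk strategy (the sizes are illustrative, not the defaults used by get_dataloader):

```python
import numpy as np

nsamples, seqlen = 4, 16
all_tokens = np.arange(50, dtype=np.int32)  # stand-in for a tokenized corpus
required_tokens = nsamples * seqlen

if len(all_tokens) < required_tokens:
    repeats = -(-required_tokens // len(all_tokens))  # ceiling division
    all_tokens = np.tile(all_tokens, repeats)

rng = np.random.default_rng(seed=42)
calibration_samples = []
for _ in range(nsamples):
    # Random starting index for each seqlen-long window.
    start = rng.integers(0, len(all_tokens) - seqlen + 1)
    calibration_samples.append(all_tokens[start : start + seqlen])

samples = np.stack(calibration_samples, axis=0)
print(samples.shape)  # (4, 16)
```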
@@ -260,7 +256,7 @@ def apply_gptq_layerwise(
     embedding_layer = None
     transformer_blocks = []
     if hasattr(model, "backbone"):
-        logging.info(" -> Detected KerasNLP model structure.")
+        logging.info("Detected KerasNLP model structure.")
         backbone = model.backbone
         transformer_blocks = backbone.transformer_layers
         # Find the embedding layer by checking for common names or by type.
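
This hunk branches on whether the model exposes a KerasNLP-style backbone. A hedged sketch of that structural check, assuming (as the diff itself does) a keras_nlp/keras_hub task model whose backbone carries `transformer_layers`:

```python
def find_transformer_blocks(model):
    # KerasNLP task models keep their decoder blocks on `model.backbone`.
    if hasattr(model, "backbone"):
        return model.backbone.transformer_layers
    # Fallback for illustration only: the real function continues with
    # additional name/type-based discovery on the model's layers.
    return getattr(model, "transformer_layers", [])
```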
@@ -311,7 +307,7 @@ def apply_gptq_layerwise(
                 "Skipping."
             )
         else:
-            logging.info(f" Found layers: {list(sub_layers_map.keys())}")
+            logging.info(f"Found layers: {list(sub_layers_map.keys())}")
             gptq_objects = {
                 name: GPTQ(layer) for name, layer in sub_layers_map.items()
             }
@@ -397,7 +393,7 @@ def quantize_model(model, config):
     """
     logging.info("Starting GPTQ quantization process...")
 
-    # 1. Load ALL data needed from the generator/source in a single call.
+    # Load ALL data needed from the generator/source in a single call.
     total_samples_to_request = config.nsamples
     full_dataloader = get_dataloader(
         config.tokenizer,
@@ -406,7 +402,7 @@ def quantize_model(model, config):
         nsamples=total_samples_to_request,
     )
 
-    # 2. Split the materialized data. This works because full_dataloader
+    # Split the materialized data. This works because full_dataloader
     # is now a NumPy array, which can be sliced and reused.
     calibration_dataloader = full_dataloader[: config.nsamples]
 
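
The comment in this final hunk is the key point: get_dataloader now returns a materialized NumPy array, so the split is plain slicing rather than consuming a one-shot generator. Illustrative shapes only:

```python
import numpy as np

nsamples, seqlen = 128, 512
full_dataloader = np.zeros((nsamples, seqlen), dtype=np.int32)  # stand-in data

# A generator could be consumed only once; an array can be sliced and reused.
calibration_dataloader = full_dataloader[:nsamples]
print(calibration_dataloader.shape)  # (128, 512)
```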