Commit 1c0ebbe

Feat (brevitas_examples/llm): Support for batched inputs in GPXQ/Qronos (#1427)
Co-authored-by: Pablo Monteagudo Lago <44771380+pablomlago@users.noreply.github.com>
1 parent 26798a4 commit 1c0ebbe

7 files changed, +134 -53 lines changed

src/brevitas_examples/llm/llm_args.py

Lines changed: 7 additions & 0 deletions
@@ -475,6 +475,13 @@ def create_args_parser() -> ArgumentParser:
         "--awq-clip",
         action="store_true",
         help="Whether to apply AWQ clipping (default: %(default)s).")
+
+    parser.add_argument(
+        '--calibration-batch-size',
+        type=int,
+        default=1,
+        help='Batch size for calibration data loader. (default: %(default)s).')
+
     return parser
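
Note: the new flag is a plain argparse integer option that defaults to 1, so existing single-sample calibration runs keep their previous behaviour. A minimal standalone sketch of the option (using a throwaway parser, not the Brevitas one):

    # Standalone argparse sketch of the new option; the real flag is registered in
    # create_args_parser() in llm_args.py.
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument(
        '--calibration-batch-size',
        type=int,
        default=1,
        help='Batch size for calibration data loader. (default: %(default)s).')

    args = parser.parse_args(['--calibration-batch-size', '4'])
    print(args.calibration_batch_size)  # 4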

src/brevitas_examples/llm/llm_quant/awq/pre_quant.py

Lines changed: 2 additions & 2 deletions
@@ -113,7 +113,7 @@ def intercept_input(
 def apply_awq(
         model: nn.Module,
         tokenizer,
-        calibration_loader: DatasetToDevice,
+        calibration_dataset: DatasetToDevice,
         args: Namespace,
         auto_scale: bool = True,
         mse_range: bool = True,
@@ -127,7 +127,7 @@ def apply_awq(
         get_blocks_attribute(model) if args.gpxq_block_name is None else args.gpxq_block_name)

     # Concatenate input_ids across the batch dimension
-    samples = torch.cat(list(map(lambda sample: sample["input_ids"], calibration_loader)), dim=0)
+    samples = torch.cat(list(map(lambda sample: sample["input_ids"], calibration_dataset)), dim=0)

     first_block = blocks[0]
     cached_args, cached_kwargs = [], []
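
Note: apply_awq now takes the unbatched calibration_dataset rather than the new DataLoader, since it concatenates all samples itself. A small sketch of that concatenation with toy tensors (shapes are illustrative):

    # Each dataset entry is a dict whose "input_ids" tensor carries a leading batch
    # dimension of 1; torch.cat stacks them into one (num_samples, seqlen) tensor.
    import torch

    calibration_dataset = [{"input_ids": torch.randint(0, 100, (1, 8))} for _ in range(4)]
    samples = torch.cat([sample["input_ids"] for sample in calibration_dataset], dim=0)
    print(samples.shape)  # torch.Size([4, 8])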

src/brevitas_examples/llm/llm_quant/data_utils.py

Lines changed: 20 additions & 1 deletion
@@ -26,6 +26,8 @@

 import random
 from typing import Any
+from typing import Callable
+from typing import Dict
 from typing import Iterable
 from typing import List
 from typing import Optional
@@ -35,6 +37,7 @@
 import numpy as np
 from optimum.utils.normalized_config import NormalizedConfigManager
 import torch
+from torch.utils.data import DataLoader
 from transformers import AutoConfig

 from brevitas_examples.llm.llm_quant.data import get_clm_dataset
@@ -163,5 +166,21 @@ def get_dataset_for_model(
         ) for _ in range(num_layers))

     data = DatasetToDevice(data, device=device)
-
     return data
+
+
+def collate_fn(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
+    kwargs = {}
+    for curr_dict in batch:
+        for key, value in curr_dict.items():
+            if isinstance(value, torch.Tensor):
+                if key not in kwargs:
+                    kwargs[key] = []
+                kwargs[key].append(value)
+            else:
+                if key not in kwargs:
+                    kwargs[key] = value
+    for key, value in kwargs.items():
+        if isinstance(value, list) and len(value) > 0:
+            kwargs[key] = torch.cat(kwargs[key], dim=0)
+    return kwargs
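
Note: the collate_fn moved here concatenates tensor values along the batch dimension and passes non-tensor values through from the first sample that provides them. A small usage sketch with toy inputs (assuming brevitas_examples is installed and importable):

    import torch

    from brevitas_examples.llm.llm_quant.data_utils import collate_fn

    batch = [
        {"input_ids": torch.ones(1, 4, dtype=torch.long), "split": "train"},
        {"input_ids": torch.zeros(1, 4, dtype=torch.long), "split": "train"},
    ]
    merged = collate_fn(batch)
    print(merged["input_ids"].shape)  # torch.Size([2, 4])
    print(merged["split"])            # 'train'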

src/brevitas_examples/llm/llm_quant/rotation_optimization.py

Lines changed: 14 additions & 19 deletions
@@ -4,6 +4,8 @@
 from dataclasses import dataclass
 from dataclasses import field
 import os
+from typing import Any
+from typing import Dict
 from typing import List
 from typing import Optional

@@ -13,12 +15,14 @@
 import torch.nn.functional as F
 import transformers
 from transformers import Trainer
+from transformers.data.data_collator import InputDataClass
 from transformers.tokenization_utils import PreTrainedTokenizerBase

 from brevitas.graph.calibrate import quantization_status_manager
 from brevitas.optim.cailey_sgd import CaileySGD
 from brevitas.utils.parametrization_utils import extract_trainable_rotation_matrices
 from brevitas_examples.common.accelerate_utils.accelerate import remove_hooks
+from brevitas_examples.llm.llm_quant.data_utils import collate_fn
 from brevitas_examples.llm.llm_quant.data_utils import DatasetToDevice


@@ -119,28 +123,19 @@ def parse_rotation_optimization_args(extra_args: Optional[List[str]] = None) ->


 # Function to create a batch
-def collate_fn(kwargs_list, return_tensors="pt"):
-    kwargs = {}
-    for curr_dict in kwargs_list:
-        for key, value in curr_dict.items():
-            if isinstance(value, torch.Tensor):
-                if key not in kwargs:
-                    kwargs[key] = []
-                kwargs[key].append(value)
-            else:
-                if key not in kwargs:
-                    kwargs[key] = value
-    for key, value in kwargs.items():
-        if isinstance(value, list) and len(value) > 0:
-            kwargs[key] = torch.cat(kwargs[key], dim=0)
-    return kwargs
+def data_collator(kwargs_list: List[InputDataClass], return_tensors: str = "pt") -> Dict[str, Any]:
+    assert (return_tensors == "pt") or (return_tensors is None), f"Only 'pt' is supported as a value for return_tensors. However {return_tensors} was received."
+    return collate_fn(kwargs_list)


 def _prepare_train_dataset(train_dataset: DatasetToDevice) -> Dataset:
     return DatasetToDevice(
-        data=[{
-            "input_ids": train_datapoint["input_ids"], "labels": train_datapoint["input_ids"]}
-            for train_datapoint in train_dataset.data],
+        data=[
+            {
+                # setting "labels" to train_datapoint["input_ids"] is correct since "labels"
+                # are just input_ids shifted by 1 and this shift is handled later on.
+                "input_ids": train_datapoint["input_ids"],
+                "labels": train_datapoint["input_ids"]} for train_datapoint in train_dataset.data],
         device=None)


@@ -191,7 +186,7 @@ def apply_rotation_optimization(
         args=training_args,
         train_dataset=train_dataset,
         eval_dataset=None,
-        data_collator=collate_fn,
+        data_collator=data_collator,
         optimizers=(optimizer, None))
     trainer.train()
     # After finishing training, set eval mode again
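
Note: the rotation-optimization module now keeps only a thin data_collator wrapper that validates return_tensors and defers to the shared collate_fn, so the HF Trainer and the calibration DataLoader batch samples the same way. A toy sketch of the wrapper's behaviour (assuming the module and its dependencies import cleanly):

    import torch

    from brevitas_examples.llm.llm_quant.rotation_optimization import data_collator

    features = [
        {"input_ids": torch.ones(1, 4, dtype=torch.long), "labels": torch.ones(1, 4, dtype=torch.long)},
        {"input_ids": torch.zeros(1, 4, dtype=torch.long), "labels": torch.zeros(1, 4, dtype=torch.long)},
    ]
    batch = data_collator(features)  # only return_tensors="pt" (or None) is accepted
    print(batch["input_ids"].shape)  # torch.Size([2, 4])
    print(batch["labels"].shape)     # torch.Size([2, 4])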

src/brevitas_examples/llm/main.py

Lines changed: 32 additions & 21 deletions
@@ -7,9 +7,11 @@
 import os
 import pprint
 import sys
+import warnings

 import numpy as np
 import torch
+from torch.utils.data import DataLoader
 from transformers import AutoModelForCausalLM
 from transformers import AutoTokenizer

@@ -40,6 +42,7 @@
 from brevitas_examples.llm.llm_quant.awq.pre_quant import apply_awq
 from brevitas_examples.llm.llm_quant.bias_corr import apply_bias_correction
 from brevitas_examples.llm.llm_quant.calibrate import apply_calibration
+from brevitas_examples.llm.llm_quant.data_utils import collate_fn
 from brevitas_examples.llm.llm_quant.data_utils import get_dataset_for_model
 from brevitas_examples.llm.llm_quant.equalize import apply_act_equalization
 from brevitas_examples.llm.llm_quant.equalize import apply_weight_equalization
@@ -92,7 +95,7 @@ def fused_rotation_no_fx(model, calibration_loader, args):
     with torch.no_grad(), rmsnorm_patch(model, model.config) as patcher:
         rmsnorm_classes = patcher.rmsnorm_classes
         with make_dynamo_compatible(model) as dynamo_comp:
-            fx_model, guards = torch._dynamo.export(dynamo_comp.model)(**calibration_loader[0])
+            fx_model, guards = torch._dynamo.export(dynamo_comp.model)(**next(iter(calibration_loader)))
         if hasattr(model, str(torch.nn.functional.scaled_dot_product_attention)):
             m_to_add = getattr(model, str(torch.nn.functional.scaled_dot_product_attention))
             fx_model.add_module(str(torch.nn.functional.scaled_dot_product_attention), m_to_add)
@@ -199,7 +202,7 @@ def model_export(model, tokenizer, ref_input, args, config=None):


 def fx_required(args):
-    return True if args.weight_equalization or args.act_equalization == 'fx' or args.rotation == 'fx' or args.ln_affine_merge or args.convert_layernorm_to_rmsnorm or args.quant_sdpa == 'fx' else False
+    return args.weight_equalization or args.act_equalization == 'fx' or args.rotation == 'fx' or args.ln_affine_merge or args.convert_layernorm_to_rmsnorm or args.quant_sdpa == 'fx'


 # Recursive function to unwrap equalized layers
@@ -232,9 +235,13 @@ def quantize_llm(args, extra_args=None):
     quant_ppl = None

     require_fx = fx_required(args)
+    if require_fx and args.calibration_batch_size > 1:
+        warnings.warn(
+            f"The provided configuration requires fx and has a batch size of {args.calibration_batch_size}.\nErrors may occur when using fx and batch_size > 1.\nIf you experience any issues try chaning the configuration to avoid using fx or to set the batch_size to 1."
+        )

     # Load the data for calibration and evaluation.
-    calibration_loader = get_dataset_for_model(
+    calibration_dataset = get_dataset_for_model(
         args.model,
         bos_preprocessing=args.bos_preprocessing,
         dataset_name=args.dataset,
@@ -246,7 +253,11 @@ def quantize_llm(args, extra_args=None):
         require_fx=require_fx and args.export_target is not None,
         device=None)

-    validation_loader = get_dataset_for_model(
+    # Batched data loader to accelerate GPXQ algorithms
+    calibration_loader = DataLoader(
+        dataset=calibration_dataset, batch_size=args.calibration_batch_size, collate_fn=collate_fn)
+
+    validation_dataset = get_dataset_for_model(
         args.model,
         bos_preprocessing=args.bos_preprocessing,
         dataset_name=args.dataset,
@@ -262,7 +273,7 @@ def quantize_llm(args, extra_args=None):
         # Extra arguments should be used as training arguments for rotation optimization
         rot_optimization_args = parse_rotation_optimization_args(extra_args=extra_args)
         # Load the data for rotation optimization
-        rot_calibration_loader = get_dataset_for_model(
+        rot_calibration_dataset = get_dataset_for_model(
            args.model,
            bos_preprocessing=args.bos_preprocessing,
            dataset_name=args.dataset,
@@ -281,7 +292,7 @@ def quantize_llm(args, extra_args=None):
        print("Float model eval...")
        model = offload_model(model)
        float_ppl = compute_perplexity(
-           model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
+           model, validation_dataset, context_length=args.seqlen // 2, tokenizer=tokenizer)
        remove_hooks(model)
        print(f"Float perplexity ({args.dataset}): {float_ppl:.3f}")

@@ -290,7 +301,7 @@ def quantize_llm(args, extra_args=None):
     with torch.no_grad(), rmsnorm_patch(model, model.config, enabled=args.replace_rmsnorm) as patcher:
         rmsnorm_classes = patcher.rmsnorm_classes
         with make_dynamo_compatible(model) as dynamo_comp:
-            model, guards = torch._dynamo.export(dynamo_comp.model)(**calibration_loader[0])
+            model, guards = torch._dynamo.export(model)(**next(iter(calibration_loader)))
         # Blockwise optimization does not work with FX at the moment
         args.gpxq_block_name = None
         model.eval()
@@ -317,7 +328,7 @@ def quantize_llm(args, extra_args=None):
         print("Inserting SDPA quantizable module")
         model = offload_model(model)
         with torch.no_grad(), functional_quantization_mode(model, {torch.nn.functional.scaled_dot_product_attention: ScaledDotProductAttention}):
-            model(**calibration_loader[0])
+            model(**next(iter(calibration_loader)))
         remove_hooks(model)
     elif args.quant_sdpa == 'eager':
         model = replace_sdpa_with_quantizable_layers(
@@ -365,7 +376,7 @@ def quantize_llm(args, extra_args=None):
         offload_model(model)
         print(f"Apply act equalization (SmoothQuant) with alpha {args.act_equalization_alpha}")
         if args.load_checkpoint:
-            loader = [calibration_loader[0]]
+            loader = [next(iter(calibration_loader))]
         else:
             loader = calibration_loader
         apply_act_equalization(
@@ -479,7 +490,7 @@ def quantize_llm(args, extra_args=None):
         apply_awq(
             model=model,
             tokenizer=tokenizer,
-            calibration_loader=calibration_loader,
+            calibration_dataset=calibration_dataset,
             args=args,
             auto_scale=args.awq_scale,
             mse_range=args.awq_clip,
@@ -522,7 +533,7 @@ def quantize_llm(args, extra_args=None):
     with quantization_cm:
         # We initialize weights scale factor
         with torch.no_grad():
-            model(**calibration_loader[0])
+            model(**next(iter(calibration_loader)))

     if args.compile_ptq:
         for m in model.modules():
@@ -540,7 +551,7 @@ def quantize_llm(args, extra_args=None):
         apply_rotation_optimization(
             model=model,
             tokenizer=tokenizer,
-            train_dataset=rot_calibration_loader,
+            train_dataset=rot_calibration_dataset,
             training_args=rot_optimization_args,
         )
         # Remove hooks from optimization
@@ -561,18 +572,19 @@ def quantize_llm(args, extra_args=None):
             dtype=torch.float32)
         model = offload_model(model)
         with torch.no_grad():
-            model(**calibration_loader[0])
+            model(**next(iter(calibration_loader)))
         print("SVDQuant applied.")

     if args.learned_round:
         print("Applying learned round...")
         if args.load_checkpoint:
             iters = 1
-            loader = [calibration_loader[0]]
+            loader = [calibration_dataset[0]]
         else:
             iters = args.learned_round_iters
-            loader = calibration_loader
+            loader = calibration_dataset
         remove_hooks(model)
+        # TODO (pml): Fix learned round type hints
         apply_learned_round(
             model,
             loader,
@@ -650,18 +662,17 @@ def quantize_llm(args, extra_args=None):
     if args.eval and not args.no_quantize:
         print("Model eval...")
         with torch.no_grad(), quant_inference_mode(model, compile=args.compile_eval):
-            model(**calibration_loader[0])
+            model(**next(iter(calibration_loader)))
             quant_ppl = compute_perplexity(
-                model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer)
+                model, validation_dataset, context_length=args.seqlen // 2, tokenizer=tokenizer)
         print(f"Quantized perplexity ({args.dataset}): {quant_ppl:.3f}")
     few_shot_eval_results = dict()
     if args.few_shot_eval == 'lm_eval':
         from lm_eval import evaluator
         from lm_eval.models.huggingface import HFLM
         with torch.no_grad(), quant_inference_mode(model, compile=args.compile_eval):
-            model(**calibration_loader[0])
+            model(**next(iter(calibration_loader)))
         batch_size = 'auto' if args.few_shot_override_batch_size is None else args.few_shot_override_batch_size
-
         wrapped_model = HFLM(
             pretrained=model, add_bos_token=True,
             batch_size=batch_size)  # need to wrap for LLM eval
@@ -681,7 +692,7 @@ def quantize_llm(args, extra_args=None):
     elif args.few_shot_eval == 'lighteval':

         with torch.no_grad(), quant_inference_mode(model, compile=args.compile_eval):
-            model(**calibration_loader[0])
+            model(**next(iter(calibration_loader)))
         remove_hooks(model)

         from brevitas_examples.llm.eval_lighteval import run_lighteval
@@ -703,7 +714,7 @@ def quantize_llm(args, extra_args=None):
         print(f"Export to {args.export_target}")
         # Currently we always export with a float32 container to avoid float16 CPU errors
         model = model.to(dtype=torch.float32)
-        model_export(model, tokenizer, calibration_loader[0], args, config)
+        model_export(model, tokenizer, next(iter(calibration_loader)), args, config)

     return {"float_ppl": float_ppl, "quant_ppl": quant_ppl, **few_shot_eval_results}, model
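
Note: main.py now builds the raw calibration_dataset once and wraps it in a batched DataLoader with the shared collate_fn; every former calibration_loader[0] access becomes next(iter(calibration_loader)), while AWQ and learned round keep consuming the unbatched calibration_dataset. A minimal sketch of the batching pattern with toy data (not the real get_dataset_for_model output):

    import torch
    from torch.utils.data import DataLoader

    from brevitas_examples.llm.llm_quant.data_utils import collate_fn

    # A map-style dataset of single-sample dicts, stand-ins for the entries the
    # example's data utilities produce; shapes and contents here are placeholders.
    calibration_dataset = [{"input_ids": torch.randint(0, 100, (1, 8))} for _ in range(8)]
    calibration_loader = DataLoader(
        dataset=calibration_dataset, batch_size=4, collate_fn=collate_fn)

    # next(iter(...)) yields one collated batch, replacing the old calibration_loader[0].
    ref_input = next(iter(calibration_loader))
    print(ref_input["input_ids"].shape)  # torch.Size([4, 8])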