Commit 1095a5c

Revert quantization additions to something that works on CUDA still
1 parent 2c33914 commit 1095a5c

File tree: README.md, generate.py, quantize.py

3 files changed: 37 additions & 41 deletions

README.md

Lines changed: 2 additions & 2 deletions
@@ -148,12 +148,12 @@ python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model_int
 To generate int4 version of model
 ```bash
 # Spits out model at checkpoints/$MODEL_REPO/model_int4.g32.$DEVICE.pth
-python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4 --groupsize 32 --device $DEVICE
+python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4 --groupsize 32
 ```
 
 To run with int4, just pass the int4 checkpoint to generate.py.
 ```bash
-python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.$DEVICE.pth --compile --device $DEVICE
+python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.pth --compile
 ```
 
 ## Speculative Sampling

generate.py

Lines changed: 2 additions & 4 deletions
@@ -225,12 +225,10 @@ def _load_model(checkpoint_path, device, precision, use_tp):
     if "int4" in str(checkpoint_path):
         print("Using int4 weight-only quantization!")
         path_comps = checkpoint_path.name.split(".")
-        assert path_comps[-3].startswith("g")
-        assert path_comps[-2] in device, "weight packed format mismatch, please rerun quantize.py!"
-        groupsize = int(path_comps[-3][1:])
+        groupsize = int(path_comps[-2][1:])
         from quantize import WeightOnlyInt4QuantHandler
         simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
-        model = simple_quantizer.convert_for_runtime(use_cuda)
+        model = simple_quantizer.convert_for_runtime()
 
     checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
     if "model" in checkpoint and "stories" in str(checkpoint_path):

quantize.py

Lines changed: 33 additions & 35 deletions
@@ -19,8 +19,6 @@
 
 from model import Transformer
 
-default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
 ##### Quantization Primitives ######
 
 def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
@@ -328,8 +326,8 @@ def create_quantized_state_dict(self):
         for fqn, mod in self.mod.named_modules():
             if isinstance(mod, torch.nn.Linear):
                 int8_weight, scales, _ = dynamically_quantize_per_channel(mod.weight.float(), -128, 127, torch.int8)
-                cur_state_dict[f"{fqn}.weight"] = int8_weight.to('cpu')
-                cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype).to('cpu')
+                cur_state_dict[f"{fqn}.weight"] = int8_weight
+                cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype)
 
         return cur_state_dict
 
@@ -365,9 +363,6 @@ def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
     weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
     return weight_int4pack, scales_and_zeros
 
-def _calc_padded_size(k, groupsize=1, innner_k_tiles=1):
-    from model import find_multiple
-    return find_multiple(k, 1024)
 
 def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
     origin_x_size = x.size()
@@ -381,29 +376,39 @@ def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
 def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = 1):
     return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
 
-def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_cuda):
+def replace_linear_int4(module, groupsize, inner_k_tiles, padding):
     for name, child in module.named_children():
         if isinstance(child, nn.Linear):
-            if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed:
+            if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
+                setattr(module, name, WeightOnlyInt4Linear(
+                    child.in_features, child.out_features, bias=False,
+                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=False,
+                ))
+            elif padding:
                 setattr(module, name, WeightOnlyInt4Linear(
                     child.in_features, child.out_features, bias=False,
-                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, use_cuda=use_cuda
+                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=True,
                 ))
         else:
-            replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, use_cuda)
+            replace_linear_int4(child, groupsize, inner_k_tiles, padding)
 
 
 class WeightOnlyInt4QuantHandler:
-    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding_allowed=True):
+    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
         self.mod = mod
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
-        self.padding_allowed = padding_allowed
+        self.padding = padding
         assert groupsize in [32, 64, 128, 256]
         assert inner_k_tiles in [2, 4, 8]
 
     @torch.no_grad()
-    def create_quantized_state_dict(self):
+    def create_quantized_state_dict(self, use_cuda = True):
+        if use_cuda:
+            device="cuda"
+        else:
+            device="cpu"
+
         cur_state_dict = self.mod.state_dict()
         for fqn, mod in self.mod.named_modules():
             if isinstance(mod, torch.nn.Linear):
@@ -415,7 +420,7 @@ def create_quantized_state_dict(self):
 
                 weight = mod.weight.data
                 if not _check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles):
-                    if self.padding_allowed:
+                    if self.padding:
                         from model import find_multiple
                         import torch.nn.functional as F
                         print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0")
@@ -426,15 +431,15 @@ def create_quantized_state_dict(self):
                             "and that groupsize and inner_k_tiles*16 evenly divide into it")
                         continue
                 weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros(
-                    weight.to(torch.bfloat16), self.groupsize, self.inner_k_tiles
+                    weight.to(torch.bfloat16).to(device=device), self.groupsize, self.inner_k_tiles
                 )
                 cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to('cpu')
                 cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to('cpu')
 
         return cur_state_dict
 
-    def convert_for_runtime(self, use_cuda):
-        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding_allowed, use_cuda)
+    def convert_for_runtime(self):
+        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
         return self.mod
 
 class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler):
@@ -458,10 +463,7 @@ def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
         # we need to do the padding here, both for q and the qparams if necessary
         def make_names_and_values_dict_func(q, qparams):
             k = q.shape[1]
-            if not _check_linear_int4_k(k, groupsize, inner_k_tiles):
-                new_k = find_multiple(k, 1024)
-            else:
-                new_k = k
+            new_k = find_multiple(k, 1024)
             # how much we need to pad the weight
             delta_k = new_k - q.shape[1]
             final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles)
@@ -474,8 +476,8 @@ def make_names_and_values_dict_func(q, qparams):
         super().__init__()
 
 
-    def convert_for_runtime(self, use_cuda):
-        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding, use_cuda)
+    def convert_for_runtime(self):
+        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
         return self.mod
 
 class WeightOnlyInt4Linear(torch.nn.Module):
@@ -486,11 +488,11 @@ class WeightOnlyInt4Linear(torch.nn.Module):
 
     def __init__(
         self, in_features: int, out_features: int,
-        bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, use_cuda=True,
+        bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, padding: bool = True,
     ) -> None:
         super().__init__()
-        self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
-        if self.padding:
+        self.padding = padding
+        if padding:
             from model import find_multiple
             self.origin_in_features = in_features
             in_features = find_multiple(in_features, 1024)
@@ -536,9 +538,9 @@ def quantize(
     percdamp: float = .01,
     blocksize: int = 128,
     label: str = '',
-    device: str = default_device,
 ) -> None:
     assert checkpoint_path.is_file(), checkpoint_path
+
     device = 'cpu'
     precision = torch.bfloat16
 
@@ -549,8 +551,6 @@ def quantize(
         model = Transformer.from_name(checkpoint_path.parent.name)
 
     checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
-    if "model" in checkpoint and "stories" in str(checkpoint_path):
-        checkpoint = checkpoint["model"]
     model.load_state_dict(checkpoint, assign=True)
     model = model.to(dtype=precision, device=device)
 
@@ -565,13 +565,12 @@ def quantize(
 
     elif mode == 'int4':
         print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization")
-        print(f"Prepacking model weights in {device} optimal layout")
         quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
         quantized_state_dict = quant_handler.create_quantized_state_dict()
 
         dir_name = checkpoint_path.parent
         base_name = checkpoint_path.name
-        new_base_name = base_name.replace('.pth', f"{label}int4.g{groupsize}.{device}.pth")
+        new_base_name = base_name.replace('.pth', f"{label}int4.g{groupsize}.pth")
 
     elif mode == 'int4-gptq':
         print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization using GPTQ...")
@@ -594,7 +593,7 @@ def quantize(
 
         dir_name = checkpoint_path.parent
         base_name = checkpoint_path.name
-        new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.{device}.pth")
+        new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.pth")
     else:
         raise ValueError(f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]")
 
@@ -618,7 +617,6 @@ def quantize(
     parser.add_argument('--percdamp', type=float, default=.01, help='gptq percentage dampening')
    parser.add_argument('--blocksize', type=int, default=128, help='blocksize for gptq')
     parser.add_argument('--label', type=str, default='_', help='label to add to output filename')
-    parser.add_argument('--device', type=str, default=default_device, help='device to use')
 
     args = parser.parse_args()
-    quantize(args.checkpoint_path, args.mode, args.groupsize, args.calibration_tasks, args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, args.percdamp, args.blocksize, args.label, args.device)
+    quantize(args.checkpoint_path, args.mode, args.groupsize, args.calibration_tasks, args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, args.percdamp, args.blocksize, args.label)
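
Taken together, a condensed, illustrative sketch of the int4 API after this revert, stitching the quantize.py and generate.py sides shown above into one snippet. The toy module, group size, and save path are placeholders, and the default `use_cuda=True` packing path assumes a CUDA-enabled PyTorch build.

```python
import torch
import torch.nn as nn

from quantize import WeightOnlyInt4QuantHandler

# Placeholder module; in the repo this is a Transformer loaded from a checkpoint.
model = nn.Sequential(nn.Linear(4096, 4096, bias=False))

groupsize = 32
handler = WeightOnlyInt4QuantHandler(model, groupsize)

# Packs weights on CUDA by default (use_cuda=True) and returns a state dict
# whose packed tensors have already been moved back to CPU.
quantized_state_dict = handler.create_quantized_state_dict()

# Output filename no longer carries a device suffix.
torch.save(quantized_state_dict, f"model_int4.g{groupsize}.pth")

# convert_for_runtime() takes no device argument anymore; it swaps each
# eligible nn.Linear for a WeightOnlyInt4Linear in place.
model = handler.convert_for_runtime()
model.load_state_dict(quantized_state_dict, assign=True)
```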
