Commit f4584b4

feat(sanity-check): add pack and load for gptq implem
1 parent: c8f7857

2 files changed: 92 additions & 8 deletions

quantize/gptq/quant.py

Lines changed: 16 additions & 0 deletions
@@ -147,6 +147,22 @@ def make_quant(module, names, bits, groupsize, name=''):
     for name1, child in module.named_children():
         make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
 
+def make_quant_custom(module, names, bits, groupsize, name=''):
+    if isinstance(module, QuantLinear):
+        return
+    for attr in dir(module):
+        tmp = getattr(module, attr)
+        name1 = name + '.' + attr if name != '' else attr
+        if name1 in names:
+            delattr(module, attr)
+            bias_name = attr.replace('w', 'b')
+            setattr(module, attr, QuantLinear(bits, groupsize, tmp.shape[0], tmp.shape[1], module.w[bias_name] is not None))
+    #TODO: No recursion yet
+    # for name1, child in module.named_children():
+    #     make_quant_custom(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
+
+
+
 class QuantLinear(nn.Module):
     def __init__(self, bits, groupsize, infeatures, outfeatures, bias, kernel_switch_threshold=128, is_cuda=is_cuda):
         super().__init__()
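Note: the new make_quant_custom replaces raw weight attributes with QuantLinear modules by attribute name instead of recursing over named_children(). Below is a minimal, self-contained sketch of that delattr/setattr swap; FakeQuantLinear and ToyModel are hypothetical stand-ins for the repo's QuantLinear and for a model that keeps its biases in a `w` dict keyed `linear{i}_b`, as the diff assumes.

import torch
import torch.nn as nn

class FakeQuantLinear(nn.Module):
    """Stand-in for the repo's QuantLinear; just records its constructor args."""
    def __init__(self, bits, groupsize, infeatures, outfeatures, bias):
        super().__init__()
        self.bits, self.groupsize = bits, groupsize
        self.infeatures, self.outfeatures, self.has_bias = infeatures, outfeatures, bias

class ToyModel(nn.Module):
    """Hypothetical layout: raw weight as an attribute, bias looked up via .w."""
    def __init__(self):
        super().__init__()
        self.w = {"linear1_b": torch.zeros(32)}
        self.linear1_w = torch.randn(32, 784)

model = ToyModel()
names = ["linear1_w"]

# Same delattr/setattr swap make_quant_custom performs (with name='' so name1 == attr).
for attr in list(dir(model)):
    if attr in names:
        tmp = getattr(model, attr)
        delattr(model, attr)                       # drop the raw weight tensor ...
        bias_name = attr.replace('w', 'b')
        setattr(model, attr, FakeQuantLinear(      # ... and install a quant layer in its place
            4, -1, tmp.shape[0], tmp.shape[1], model.w[bias_name] is not None))

print(type(model.linear1_w).__name__)              # -> FakeQuantLinear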

quantize/gptq/sanity_check_main.py

Lines changed: 76 additions & 8 deletions
@@ -4,6 +4,7 @@
 import torch
 import torch.nn as nn
 import torch.optim as optim
+from collections import OrderedDict
 
 from sanity_check_utils import seed_everything, MNISTloader, SimpleNet, train, evaluate, SimpleNet_V2
 from gptq import *
@@ -34,9 +35,8 @@ def load_quant(model, checkpoint, wbits, groupsize):
 
     # Don't quantize the last layer because qzeros is empty (I don't know why they create qzeros that way)
     # (gptq.py:L235, second dimension of qzeros is 0 because last layer is 10 for classification)
-    for name in ["linear4"]:
-        if name in layers:
-            del layers[name]
+    if "linear4" in layers:
+        del layers["linear4"]
 
     make_quant(model, layers, wbits, groupsize)
     model.load_state_dict(torch.load(checkpoint))
@@ -292,7 +292,7 @@ def fasterquant(self, layer_id, quantizers):
            print(layer_id, name)
            print('Quantizing ...')
            scale,zero,g_idx = self.gptq[name].fasterquant(percdamp=0.01, groupsize=GROUPSIZE, actorder=False)
-            quantizers[f"linear{layer_id + 1}"] = (self.gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu())
+            quantizers[f"linear{layer_id}_w"] = (self.gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu())
 
     ## end GPTQ_CUSTOM
 
@@ -344,10 +344,61 @@ def quantize_gptq_custom(model, train_loader):
 
 
 def model_pack_custom(model, quantizers, wbits, groupsize):
-    pass
+    # Extract weights and bias from model
+    is_weight = re.compile(r'^linear\d+_w$')
+    weights, bias = OrderedDict(), OrderedDict()
+    for name, param in model.w.items():
+        if is_weight.match(name):
+            weights[name] = param
+        else:
+            bias[name] = param
+
+    # Create linear layer out of weights and bias
+    layers = {}
+    for (w_name, w_param), (_, b_param) in zip(weights.items(), bias.items()):
+        layers[w_name] = nn.Linear(w_param.shape[1], w_param.shape[0], bias=True)
+        layers[w_name].weight.data = w_param
+        layers[w_name].bias.data = b_param
+
+    make_quant_custom(model, quantizers, wbits, groupsize)
+    qlayers = find_layers(model, [QuantLinear])
+
+    print('Packing ...')
+    for name in qlayers:
+        print(name)
+        quantizers[name],scale,zero,g_idx = quantizers[name]
+        qlayers[name].pack(layers[name], scale, zero, g_idx)
+    print('Done.')
+    return model
+
+def load_quant_custom(model, checkpoint, wbits, groupsize):
+    print('Loading model ...')
+    model = model.eval()
+    # Extract weights and bias from model
+    is_weight = re.compile(r'^linear\d+_w$')
+    weights, bias = OrderedDict(), OrderedDict()
+    for name, param in model.w.items():
+        if is_weight.match(name):
+            weights[name] = param
+        else:
+            bias[name] = param
+
+    # Create linear layer out of weights and bias
+    layers = {}
+    for (w_name, w_param), (_, b_param) in zip(weights.items(), bias.items()):
+        layers[w_name] = nn.Linear(w_param.shape[1], w_param.shape[0], bias=True)
+        layers[w_name].weight.data = w_param
+        layers[w_name].bias.data = b_param
+
+    # Don't quantize the last layer because qzeros is empty (I don't know why they create qzeros that way)
+    # (gptq.py:L235, second dimension of qzeros is 0 because last layer is 10 for classification)
+    if "linear3_w" in layers:
+        del layers["linear3_w"]
+    make_quant_custom(model, layers, wbits, groupsize)
+    model.load_state_dict(torch.load(checkpoint))
+    print('Done.')
+    return model
 
-def load_quant_custom(model, quantizers, wbits, groupsize):
-    pass
 
 def assert_parameters(model, model_custom):
     is_weight = re.compile(r'^linear\d+.weight$')
@@ -371,6 +422,7 @@ def assert_parameters(model, model_custom):
    parser.add_argument("--eval_gptq", action="store_true")
    parser.add_argument("--train_custom", action="store_true")
    parser.add_argument("--gptq_custom", action="store_true")
+    parser.add_argument("--eval_gptq_custom", action="store_true")
    parser.add_argument("--pyquant", action="store_true")
 
    args = parser.parse_args()
@@ -381,7 +433,9 @@ def assert_parameters(model, model_custom):
    criterion = nn.CrossEntropyLoss()
    train_loader, _, _ = MNISTloader(train_val_split=0.95).load()
 
-    #TODO: Do Custom packing
+    #TODO: Do custom eval gptq
+    #TODO: Is reference GPTQ quantizing bias as well ?
+    #TODO: Add seed everywhere in GPT for reproducibility
 
    ## ================== REFERENCE ==================
    if args.train:
@@ -430,6 +484,20 @@ def assert_parameters(model, model_custom):
        model_pack_custom(model, quantizers, WBITS, GROUPSIZE)
        torch.save(model.state_dict(), "model_quantized_custom.pt")
        print("Done Custom GPTQ")
+    elif args.eval_gptq_custom:
+        model = GPTQ_CUSTOM("./model_custom.pt")
+        device = torch.device("cuda:0")
+        model = load_quant_custom(model, "model_quantized_custom.pt", WBITS, GROUPSIZE)
+        model = model.to(device)
+
+        #TODO: Fix eval
+        # start = time.time()
+        # val_loss, val_acc = evaluate(device, model, criterion, train_loader)
+        # end = time.time()
+
+        # print(f"wbits = {WBITS} using {device}")
+        # print(f"val_loss: {val_loss:.3f} \t val_acc: {val_acc:.3f}")
+        # print(f"Latency: {end - start}")
    ## ================== MISC ==================
    elif args.pyquant:
        # Baseline post-training quantization from Pytorch
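Note: both model_pack_custom and load_quant_custom start the same way: split the model's `w` dict into weights and biases (keys `linear{i}_w` / `linear{i}_b`) and rebuild plain nn.Linear layers for the subsequent QuantLinear pack/load step. A runnable illustration of just that split follows; the dummy `w` dict is a hypothetical stand-in for the real model.

import re
from collections import OrderedDict

import torch
import torch.nn as nn

# Hypothetical model.w layout, mirroring the linear{i}_w / linear{i}_b convention.
w = OrderedDict([
    ("linear1_w", torch.randn(32, 784)),
    ("linear1_b", torch.zeros(32)),
    ("linear2_w", torch.randn(10, 32)),
    ("linear2_b", torch.zeros(10)),
])

is_weight = re.compile(r'^linear\d+_w$')
weights, bias = OrderedDict(), OrderedDict()
for name, param in w.items():
    if is_weight.match(name):
        weights[name] = param
    else:
        bias[name] = param

# Rebuild plain nn.Linear layers: weights are stored as (out_features, in_features),
# hence nn.Linear(w_param.shape[1], w_param.shape[0]).
layers = {}
for (w_name, w_param), (_, b_param) in zip(weights.items(), bias.items()):
    layers[w_name] = nn.Linear(w_param.shape[1], w_param.shape[0], bias=True)
    layers[w_name].weight.data = w_param
    layers[w_name].bias.data = b_param

# These reconstructed layers are what qlayers[name].pack(layers[name], ...) consumes,
# with quantizers keyed the same way: {"linear1_w": (quantizer, scale, zero, g_idx), ...}.
print({k: (v.in_features, v.out_features) for k, v in layers.items()})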
