
Commit cf14124

feat(quantize): readapt GPTQ for rwkv

1 parent e74d72a commit cf14124

File tree

1 file changed: +97 −69 lines changed

quantize/tmp_rwkv.py

Lines changed: 97 additions & 69 deletions
@@ -5,11 +5,14 @@

 import os
 import torch.nn.functional as F
-import torch.nn as nn
+from collections import OrderedDict
 import time
 import math
 import re

+WBITS = 8
+GROUPSIZE = -1
+
 class GPTQ_RWKV(RWKV):

     ### begin GPTQ
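The two new module-level constants replace values that were hard-coded further down: `WBITS` is the target bit width passed to `Quantizer.configure`, and `GROUPSIZE = -1` is the usual GPTQ sentinel for "no column grouping", i.e. a single set of quantization parameters spanning all columns. A minimal sketch of that convention (the helper name is illustrative, not from the patch):

    def effective_groupsize(groupsize: int, columns: int) -> int:
        # -1 means one group spanning every column, so find_params()
        # runs once per weight row instead of once per column group.
        return columns if groupsize == -1 else groupsize

    assert effective_groupsize(-1, 768) == 768   # per-channel only
    assert effective_groupsize(128, 768) == 128  # 6 groups of 128 columns

This mirrors the `groupsize = groupsize if groupsize != -1 else self.columns` line added later in `fasterquant`.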
@@ -29,17 +32,15 @@ def __init__(self, weight, name):
         self.deactivate_add_batch_call = False

     def add_batch(self, inp):
-
         # After calling fasterquant, we don't want to call add_batch anymore
         if self.deactivate_add_batch_call:
             return

         if len(inp.shape) == 2:
             inp = inp.unsqueeze(0)

-        #TODO: is the case with len = 1 still necessary ?
-        tmp = 1 if len(inp.shape) == 1 else inp.shape[0]
-
+        tmp = inp.shape[0]
+
         # Assume weight come from nn.Linear
         if len(inp.shape) == 3:
             inp = inp.reshape((-1, inp.shape[-1]))
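For context, `add_batch` is the hook that feeds calibration activations into the layer's running Hessian estimate; the patch only simplifies the batch-size bookkeeping (`tmp`) and adds the early-exit flag. In the reference GPTQ implementation, the accumulation that follows the reshaping above is roughly this (a standalone sketch, assuming `H` starts as a zero `columns x columns` matrix):

    import math
    import torch

    def update_hessian(H, nsamples, inp):
        # Running average of 2 * X X^T over all calibration samples seen so far.
        # H: (columns, columns); inp: (tokens, columns) activations for one batch.
        tmp = inp.shape[0]
        inp = inp.t().float()                  # (columns, tokens)
        H = H * (nsamples / (nsamples + tmp))  # downweight the old average
        nsamples += tmp
        inp = math.sqrt(2 / nsamples) * inp
        H = H + inp.matmul(inp.t())            # add this batch's contribution
        return H, nsamples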
@@ -52,7 +53,9 @@ def add_batch(self, inp):

     def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False):
         W = self.weight.data.clone()
-        # Need to transpose here, same reason as in __init__ with self.columns
+        # OLD: Need to transpose here, same reason as in __init__ with self.columns
+        # UPDATE: no need to transpose as we already transpose in my_linear()
+        # UPDATE2: for rwkv, this is necessary
         W = W.t()
         W = W.float()

@@ -63,10 +66,11 @@ def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False)

         H = self.H
         del self.H
+
         dead = torch.diag(H) == 0
         H[dead, dead] = 1
         W[:, dead] = 0
-
+
         if actorder:
             perm = torch.argsort(torch.diag(H), descending=True)
             W = W[:, perm]
@@ -82,6 +86,11 @@ def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False)
         H = torch.cholesky_inverse(H)
         H = torch.linalg.cholesky(H, upper=True)
         Hinv = H
+
+        g_idx = []
+        scale = []
+        zero = []
+        now_idx = 1

         for i1 in range(0, self.columns, blocksize):
             i2 = min(i1 + blocksize, self.columns)
@@ -101,6 +110,11 @@ def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False)
                 if (i1 + i) % groupsize == 0:
                     self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True)

+                    if ((i1 + i) // groupsize) - now_idx == -1:
+                        scale.append(self.quantizer.scale)
+                        zero.append(self.quantizer.zero)
+                        now_idx += 1
+
                 q = quantize(
                     w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq
                 ).flatten()
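The inserted block snapshots one `(scale, zero)` pair per group: `find_params` refits the quantization grid on each new slice of `groupsize` columns, and the `now_idx` counter ensures the pair is appended exactly once per group. The `quantize` call itself is plain round-to-nearest on an asymmetric grid; in the reference GPTQ code it amounts to:

    import torch

    def quantize(x, scale, zero, maxq):
        # Map to an integer code in [0, maxq], then back to its float image.
        q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
        return scale * (q - zero)

With `WBITS = 8`, the quantizer uses `maxq = 2**8 - 1 = 255`.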
@@ -116,15 +130,27 @@ def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False)

             W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

+
         torch.cuda.synchronize()
         print('time %.2f' % (time.time() - tick))
         print('error', torch.sum(Losses).item())
-
+
+        groupsize = groupsize if groupsize != -1 else self.columns
+        g_idx = [i // groupsize for i in range(self.columns)]
+        g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
         if actorder:
             invperm = torch.argsort(perm)
             Q = Q[:, invperm]
+            g_idx = g_idx[invperm]

         self.weight.data = Q.reshape(self.weight.shape).to(self.weight.data.dtype)
+
+        if scale == []:
+            scale.append(self.quantizer.scale)
+            zero.append(self.quantizer.zero)
+        scale = torch.cat(scale, dim=1)
+        zero = torch.cat(zero, dim=1)
+        return scale, zero, g_idx

     ### end GPTQ

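`fasterquant` now returns everything a later packing step needs to dequantize: the per-group `scale` and `zero` tensors, plus `g_idx`, which maps every weight column to the group whose parameters it should use (and follows the `actorder` permutation when that is enabled). A small illustration of the mapping:

    import torch

    # 8 columns with groupsize 4: columns 0-3 read group 0, columns 4-7 group 1.
    columns, groupsize = 8, 4
    g_idx = torch.tensor([i // groupsize for i in range(columns)], dtype=torch.int32)
    print(g_idx)  # tensor([0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)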
@@ -134,6 +160,7 @@ def __init__(self, model, strategy):
         for i in range(self.args.n_layer):
             assert self.strategy[i].device == "cpu"

+    #TODO: Change to match my implem
     def _fill_subset(self, layer_id):
         # Keep only layer within block layer_id
         is_weight = re.compile(f'^blocks\.{layer_id}\..*\.weight$')
@@ -146,18 +173,18 @@ def _fill_subset(self, layer_id):
         if is_last_layer:
             self.subset["head.weight"] = self.w["head.weight"]

-
+        return self.subset
+
     def alloc_gptq(self, layer_id):
         self.subset = {}
         self.gptq = {}

-        self._fill_subset(layer_id)
-
+        self.subset = self._fill_subset(layer_id)
+
         for name in self.subset:
             self.gptq[name] = self.GPTQ(self.subset[name], name)
             self.gptq[name].quantizer = Quantizer()
-            #TODO: add argparse to configure
-            self.gptq[name].quantizer.configure(bits=4, perchannel=True, sym=False, mse=False, trits=False)
+            self.gptq[name].quantizer.configure(bits=WBITS, perchannel=True, sym=False, mse=False, trits=False)

     def free_gptq(self):
         self.subset = {}
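`configure(bits=WBITS, perchannel=True, sym=False, ...)` requests an asymmetric quantizer with one `(scale, zero)` per output channel, fitted to each row's min/max. A sketch of what that parameter search does under exactly these flags (the real `Quantizer.find_params` also covers the symmetric, mse, and trits paths):

    import torch

    def find_params_perchannel_asym(w, bits=8):
        # w: (rows, columns). One scale/zero per row, min/max fit, zero included.
        maxq = 2 ** bits - 1
        xmin = torch.clamp(w.min(dim=1, keepdim=True).values, max=0)
        xmax = torch.clamp(w.max(dim=1, keepdim=True).values, min=0)
        scale = (xmax - xmin) / maxq
        scale[scale == 0] = 1.0            # simplified guard for all-zero rows
        zero = torch.round(-xmin / scale)  # integer zero point in [0, maxq]
        return scale, zero, maxq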
@@ -166,11 +193,10 @@ def free_gptq(self):
     def fasterquant(self, layer_id, quantizers):

         for name in self.subset:
-            print(f"Quantizing {name} of layer {layer_id}")
-            #TODO: add argparse to fastquant
-            self.gptq[name].fasterquant(percdamp=0.01, groupsize=-1, actorder=False)
-            # self.gptq[name].fastquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order)
-            quantizers[name] = self.gptq[name].quantizer
+            print(layer_id, name)
+            print('Quantizing ...')
+            scale, zero, g_idx = self.gptq[name].fasterquant(percdamp=0.01, groupsize=GROUPSIZE, actorder=False)
+            quantizers[f"linear{layer_id}_w"] = (self.gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu())

     ### end GPTQ_RWKV

@@ -326,7 +352,7 @@ def forward_block(self, x, state, i, seq_mode, full_output=False):
         orx = self.w[f'{att}output.weight_rx'] if wtype == torch.uint8 else x
         omy = self.w[f'{att}output.weight_my'] if wtype == torch.uint8 else x
         ory = self.w[f'{att}output.weight_ry'] if wtype == torch.uint8 else x
-
+
         x, state[i*5+0], state[i*5+1], state[i*5+2], state[i*5+3] = ATT(
             x=x, sx=state[i*5+0], aa=state[i*5+1], bb=state[i*5+2], pp=state[i*5+3],
             ln_w=self.w[f'{bbb}ln1.weight'], ln_b=self.w[f'{bbb}ln1.bias'],
@@ -338,12 +364,6 @@ def forward_block(self, x, state, i, seq_mode, full_output=False):
             rmx=rmx, rrx=rrx, rmy=rmy, rry=rry,
             omx=omx, orx=orx, omy=omy, ory=ory,
         )
-
-        # Deactivate add_batch() after quantization is applied
-        kw.deactivate_add_batch_call = True
-        vw.deactivate_add_batch_call = True
-        rw.deactivate_add_batch_call = True
-        ow.deactivate_add_batch_call = True

         if dd.stream:
             del kw, vw, rw, ow
@@ -378,11 +398,6 @@ def forward_block(self, x, state, i, seq_mode, full_output=False):
             vmx=vmx, vrx=vrx, vmy=vmy, vry=vry,
             rmx=rmx, rrx=rrx, rmy=rmy, rry=rry,
         )
-
-        # Deactivate add_batch() after quantization is applied
-        kw.deactivate_add_batch_call = True
-        vw.deactivate_add_batch_call = True
-        rw.deactivate_add_batch_call = True

         if dd.stream:
             del kw, vw, rw
@@ -392,7 +407,6 @@ def forward_block(self, x, state, i, seq_mode, full_output=False):
             x = x / 2

         is_last_layer = i == (args.n_layer - 1)
-
         if is_last_layer:
             dd = self.strategy[args.n_layer]
             x = x[-1,:] if (seq_mode and (not full_output)) else x
@@ -410,63 +424,77 @@ def forward_block(self, x, state, i, seq_mode, full_output=False):

     ### end RWKV

-model = GPTQ_RWKV("./RWKV-4-Pile-169M-20220807-8023.pth", strategy='cpu fp32')
-
-NSAMPLES=2
-HIDDEN_SIZE=model.args.n_embd
-SEQLEN=1024 # cf https://huggingface.co/BlinkDL/rwkv-4-pile-169m
-
-# train_tokens, test_tokens = get_loaders(
-#     dataset_name="wikitext2",
-#     nsamples=NSAMPLES,
-#     seed=42,
-#     seqlen=SEQLEN,
-#     model=model
-# )
-
-# tokens = torch.cat([inp for inp, _ in train_tokens], dim=0)
-tokens = torch.zeros((NSAMPLES, SEQLEN), dtype=torch.int64)
-print("tokens.shape", tokens.shape)
-
-is_last_layer = lambda x: x == (model.args.n_layer - 1)
-
-start_time = time.time()
-
-#TODO: Do the same in GPU side
-with torch.no_grad():
+@torch.no_grad()
+def quantize_gptq_custom(model, tokens):
+    nsamples = tokens.shape[0]
     seq_mode = len(tokens) > 1
+    is_last_layer = lambda x: x == (model.args.n_layer - 1)
+
     inps = model.w['emb.weight'][tokens if seq_mode else tokens[0]]
     outs = torch.zeros_like(inps)
-
     quantizers = {}
+
     for layer_id in range(model.args.n_layer):
+
+        print(f"Quantizing layer {layer_id} ...")

         model.alloc_gptq(layer_id)

-        for j in range(NSAMPLES):
+        for i in range(nsamples):
+            #TODO: Are the outs values normal? (they look almost all the same)
             if not is_last_layer(layer_id):
-                outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode)
+                outs[i] = model.forward_block(inps[i], state=None, i=layer_id, seq_mode=seq_mode)
             else:
-                _ = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode)
-
+                _ = model.forward_block(inps[i], state=None, i=layer_id, seq_mode=seq_mode)
+
+        for gptq_layer in model.gptq.values():
+            gptq_layer.deactivate_add_batch_call = True
+
+        tmp = model.w["blocks.0.att.key.weight"]
+
         model.fasterquant(layer_id, quantizers)

-        for j in range(NSAMPLES):
+        for i in range(nsamples):
             if not is_last_layer(layer_id):
-                outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode)
+                outs[i] = model.forward_block(inps[i], state=None, i=layer_id, seq_mode=seq_mode)
             else:
-                _ = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode)
-
+                _ = model.forward_block(inps[i], state=None, i=layer_id, seq_mode=seq_mode)
+
+        # Assign the quantized weights to the model
+        for key in model.gptq.keys():
+            model.w[key].copy_(model.gptq[key].weight)
+
         model.free_gptq()

         # We need to pass the outputs of block i as input of block i+1 (except for last block)
         if not is_last_layer(layer_id):
            inps, outs = outs, inps

-end_time = time.time()
+    return quantizers
+
+if __name__ == "__main__":
+
+    model = GPTQ_RWKV("./RWKV-4-Pile-169M-20220807-8023.pth", strategy='cpu fp32')
+
+    NSAMPLES = 2
+    HIDDEN_SIZE = model.args.n_embd
+    SEQLEN = 1024  # cf https://huggingface.co/BlinkDL/rwkv-4-pile-169m
+
+    train_tokens, test_tokens = get_loaders(
+        dataset_name="wikitext2",
+        nsamples=NSAMPLES,
+        seed=42,
+        seqlen=SEQLEN,
+        model=model
+    )
+
+    tokens = torch.cat([inp for inp, _ in train_tokens], dim=0)
+    tokens = torch.zeros((NSAMPLES, SEQLEN), dtype=torch.int64)
+    print("tokens.shape", tokens.shape)

-print(f"Done in {end_time - start_time:.2f} seconds")
+    import pdb; pdb.set_trace()
+    # quantizers = quantize_gptq_custom(model, tokens)

-# TODO: Do something with quantizers dictionary
-# TODO: pack3 save model
+    # model_pack_custom(model, quantizers, WBITS, GROUPSIZE)
+    # torch.save(model.state_dict(), "model_quantized_custom.pt")
+    # print("Done Custom GPTQ")
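The driver logic is unchanged in spirit but now lives in `quantize_gptq_custom`: for each block it (1) runs every calibration sample through `forward_block` so the wrapped weights accumulate Hessian statistics via `add_batch`, (2) freezes those statistics by setting `deactivate_add_batch_call` (which this commit moves out of `forward_block` and into the driver, covering all of the block's matrices in one loop), (3) quantizes the block, then (4) re-runs it so the quantized outputs become the next block's inputs via the `inps, outs` swap. Once the `pdb` breakpoint is dropped and the call uncommented, the intended flow is as follows (a usage sketch; `model_pack_custom` does not exist yet at this commit and is shown only as the planned API):

    # Hypothetical end-to-end run for this commit, with the debug lines removed.
    model = GPTQ_RWKV("./RWKV-4-Pile-169M-20220807-8023.pth", strategy='cpu fp32')
    tokens = torch.zeros((2, 1024), dtype=torch.int64)  # stand-in calibration tokens
    quantizers = quantize_gptq_custom(model, tokens)
    # Planned follow-ups, still commented out in the source:
    # model_pack_custom(model, quantizers, WBITS, GROUPSIZE)
    # torch.save(model.state_dict(), "model_quantized_custom.pt")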
