 
 import os
 import torch.nn.functional as F
-import torch.nn as nn
+from collections import OrderedDict
 import time
 import math
 import re
 
+WBITS = 8
+GROUPSIZE = -1
+
 class GPTQ_RWKV(RWKV):
 
     ### begin GPTQ
@@ -29,17 +32,15 @@ def __init__(self, weight, name):
             self.deactivate_add_batch_call = False
 
         def add_batch(self, inp):
-
             # After calling fasterquant, we don't want to call add_batch anymore
             if self.deactivate_add_batch_call:
                 return
 
             if len(inp.shape) == 2:
                 inp = inp.unsqueeze(0)
 
-            #TODO: is the case with len = 1 still necessary?
-            tmp = 1 if len(inp.shape) == 1 else inp.shape[0]
-
+            tmp = inp.shape[0]
+
             # Assume the weight comes from nn.Linear
             if len(inp.shape) == 3:
                 inp = inp.reshape((-1, inp.shape[-1]))
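
The running update that `add_batch` performs on `self.H` sits outside this hunk; for context, the upstream GPTQ recipe it follows keeps H ≈ 2·E[x xᵀ] per layer. A minimal standalone sketch of that update, assuming the reference GPTQ implementation (`add_batch_reference` is a hypothetical name, not code from this commit):

import math
import torch

def add_batch_reference(H, nsamples, inp):
    # Assumed upstream GPTQ logic: maintain H ~= 2 * E[x x^T] over calibration
    # samples; `inp` is (tokens, features), as after the reshape above.
    tmp = inp.shape[0]
    inp = inp.t().float()                    # (features, tokens)
    H *= nsamples / (nsamples + tmp)         # downweight the old estimate
    nsamples += tmp
    inp = math.sqrt(2 / nsamples) * inp
    H += inp.matmul(inp.t())                 # add this batch's x x^T contribution
    return H, nsamples

`deactivate_add_batch_call` simply short-circuits this accumulation on the second pass over the calibration batches, once `fasterquant` has already consumed H.
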
@@ -52,7 +53,9 @@ def add_batch(self, inp):
 
         def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False):
             W = self.weight.data.clone()
-            # Need to transpose here, same reason as in __init__ with self.columns
+            # OLD: Need to transpose here, same reason as in __init__ with self.columns
+            # UPDATE: no need to transpose, since we already transpose in my_linear()
+            # UPDATE2: for RWKV, the transpose is necessary
             W = W.t()
             W = W.float()
 
@@ -63,10 +66,11 @@ def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False)
 
             H = self.H
             del self.H
+
             dead = torch.diag(H) == 0
             H[dead, dead] = 1
             W[:, dead] = 0
-
+
             if actorder:
                 perm = torch.argsort(torch.diag(H), descending=True)
                 W = W[:, perm]
@@ -82,6 +86,11 @@ def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False)
             H = torch.cholesky_inverse(H)
             H = torch.linalg.cholesky(H, upper=True)
             Hinv = H
+
+            g_idx = []
+            scale = []
+            zero = []
+            now_idx = 1
 
             for i1 in range(0, self.columns, blocksize):
                 i2 = min(i1 + blocksize, self.columns)
@@ -101,6 +110,11 @@ def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False)
                         if (i1 + i) % groupsize == 0:
                             self.quantizer.find_params(W[:, (i1 + i):(i1 + i + groupsize)], weight=True)
 
+                        if ((i1 + i) // groupsize) - now_idx == -1:
+                            scale.append(self.quantizer.scale)
+                            zero.append(self.quantizer.zero)
+                            now_idx += 1
+
                     q = quantize(
                         w.unsqueeze(1), self.quantizer.scale, self.quantizer.zero, self.quantizer.maxq
                     ).flatten()
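
To make the group bookkeeping above concrete (a worked example; the enclosing `if groupsize != -1:` guard is assumed from the reference implementation and elided by this hunk):

# With groupsize=128, find_params() is refreshed at columns 0, 128, 256, ...,
# and the `now_idx` check appends exactly one (scale, zero) pair per group:
# column (i1 + i) belongs to group (i1 + i) // 128, e.g. column 300 -> group 2,
# so it is quantized with scale[:, 2] / zero[:, 2].
# With groupsize=-1 (the GROUPSIZE default used below) this branch never runs:
# `scale`/`zero` stay empty during the loop and are filled once after it,
# so the whole row is treated as a single group.
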
@@ -116,15 +130,27 @@ def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False)
 
                 W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
 
+
             torch.cuda.synchronize()
             print('time %.2f' % (time.time() - tick))
             print('error', torch.sum(Losses).item())
-
+
+            groupsize = groupsize if groupsize != -1 else self.columns
+            g_idx = [i // groupsize for i in range(self.columns)]
+            g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
             if actorder:
                 invperm = torch.argsort(perm)
                 Q = Q[:, invperm]
+                g_idx = g_idx[invperm]
 
             self.weight.data = Q.reshape(self.weight.shape).to(self.weight.data.dtype)
+
+            if scale == []:
+                scale.append(self.quantizer.scale)
+                zero.append(self.quantizer.zero)
+            scale = torch.cat(scale, dim=1)
+            zero = torch.cat(zero, dim=1)
+            return scale, zero, g_idx
 
     ### end GPTQ
 
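`fasterquant` now returns the per-group parameters in addition to writing the fake-quantized matrix back into `self.weight`. Assuming the standard GPTQ convention, a later packing step would map columns to groups via `g_idx` as sketched below (`pack_reference` is a hypothetical helper, not part of this commit; rows/columns refer to the transposed matrix `fasterquant` operates on):

import torch

def pack_reference(W, scale, zero, g_idx):
    # W:     (rows, columns) fake-quantized weights written back by fasterquant
    # scale: (rows, n_groups); zero: (rows, n_groups); g_idx: (columns,)
    # Column j is encoded with the parameters of group g_idx[j].
    q = torch.round(W / scale[:, g_idx]) + zero[:, g_idx]   # integer codes (clamping to [0, maxq] omitted)
    w_hat = (q - zero[:, g_idx]) * scale[:, g_idx]          # dequantized weights
    return q, w_hat
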
@@ -134,6 +160,7 @@ def __init__(self, model, strategy):
         for i in range(self.args.n_layer):
             assert self.strategy[i].device == "cpu"
 
+    #TODO: Change to match my implementation
     def _fill_subset(self, layer_id):
         # Keep only the weights belonging to block layer_id
         is_weight = re.compile(f'^blocks\.{layer_id}\..*\.weight$')
@@ -146,18 +173,18 @@ def _fill_subset(self, layer_id):
         if is_last_layer:
             self.subset["head.weight"] = self.w["head.weight"]
 
-
+        return self.subset
+
     def alloc_gptq(self, layer_id):
         self.subset = {}
         self.gptq = {}
 
-        self._fill_subset(layer_id)
-
+        self.subset = self._fill_subset(layer_id)
+
         for name in self.subset:
             self.gptq[name] = self.GPTQ(self.subset[name], name)
             self.gptq[name].quantizer = Quantizer()
-            #TODO: add argparse to configure
-            self.gptq[name].quantizer.configure(bits=4, perchannel=True, sym=False, mse=False, trits=False)
+            self.gptq[name].quantizer.configure(bits=WBITS, perchannel=True, sym=False, mse=False, trits=False)
 
     def free_gptq(self):
         self.subset = {}
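
`alloc_gptq`, `fasterquant` (next hunk) and `free_gptq` are driven once per block; `quantize_gptq_custom` further down does exactly this, and the intended lifecycle is simply:

# Per-block lifecycle (sketch; mirrors quantize_gptq_custom below):
model.alloc_gptq(layer_id)                 # wrap each blocks.{layer_id}.*.weight in a GPTQ object
# ... run the calibration batches through forward_block() so add_batch() sees activations ...
model.fasterquant(layer_id, quantizers)    # quantize in place and collect (scale, zero, g_idx)
model.free_gptq()                          # drop per-block GPTQ state before the next block
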
@@ -166,11 +193,10 @@ def free_gptq(self):
     def fasterquant(self, layer_id, quantizers):
 
         for name in self.subset:
-            print(f"Quantizing {name} of layer {layer_id}")
-            #TODO: add argparse to fastquant
-            self.gptq[name].fasterquant(percdamp=0.01, groupsize=-1, actorder=False)
-            # self.gptq[name].fastquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order)
-            quantizers[name] = self.gptq[name].quantizer
+            print(layer_id, name)
+            print('Quantizing ...')
+            scale, zero, g_idx = self.gptq[name].fasterquant(percdamp=0.01, groupsize=GROUPSIZE, actorder=False)
+            quantizers[f"linear{layer_id}_w"] = (self.gptq[name].quantizer.cpu(), scale.cpu(), zero.cpu(), g_idx.cpu())
 
     ### end GPTQ_RWKV
 
@@ -326,7 +352,7 @@ def forward_block(self, x, state, i, seq_mode, full_output=False):
         orx = self.w[f'{att}output.weight_rx'] if wtype == torch.uint8 else x
         omy = self.w[f'{att}output.weight_my'] if wtype == torch.uint8 else x
         ory = self.w[f'{att}output.weight_ry'] if wtype == torch.uint8 else x
-
+
         x, state[i*5+0], state[i*5+1], state[i*5+2], state[i*5+3] = ATT(
             x=x, sx=state[i*5+0], aa=state[i*5+1], bb=state[i*5+2], pp=state[i*5+3],
             ln_w=self.w[f'{bbb}ln1.weight'], ln_b=self.w[f'{bbb}ln1.bias'],
@@ -338,12 +364,6 @@ def forward_block(self, x, state, i, seq_mode, full_output=False):
             rmx=rmx, rrx=rrx, rmy=rmy, rry=rry,
             omx=omx, orx=orx, omy=omy, ory=ory,
             )
-
-        # Deactivate add_batch() after quantization is applied
-        kw.deactivate_add_batch_call = True
-        vw.deactivate_add_batch_call = True
-        rw.deactivate_add_batch_call = True
-        ow.deactivate_add_batch_call = True
 
         if dd.stream:
             del kw, vw, rw, ow
@@ -378,11 +398,6 @@ def forward_block(self, x, state, i, seq_mode, full_output=False):
             vmx=vmx, vrx=vrx, vmy=vmy, vry=vry,
             rmx=rmx, rrx=rrx, rmy=rmy, rry=rry,
             )
-
-        # Deactivate add_batch() after quantization is applied
-        kw.deactivate_add_batch_call = True
-        vw.deactivate_add_batch_call = True
-        rw.deactivate_add_batch_call = True
 
         if dd.stream:
             del kw, vw, rw
@@ -392,7 +407,6 @@ def forward_block(self, x, state, i, seq_mode, full_output=False):
                 x = x / 2
 
         is_last_layer = i == (args.n_layer - 1)
-
         if is_last_layer:
             dd = self.strategy[args.n_layer]
             x = x[-1,:] if (seq_mode and (not full_output)) else x
@@ -410,63 +424,77 @@ def forward_block(self, x, state, i, seq_mode, full_output=False):
 
     ### end RWKV
 
-model = GPTQ_RWKV("./RWKV-4-Pile-169M-20220807-8023.pth", strategy='cpu fp32')
-
-NSAMPLES = 2
-HIDDEN_SIZE = model.args.n_embd
-SEQLEN = 1024  # cf https://huggingface.co/BlinkDL/rwkv-4-pile-169m
-
-# train_tokens, test_tokens = get_loaders(
-#     dataset_name="wikitext2",
-#     nsamples=NSAMPLES,
-#     seed=42,
-#     seqlen=SEQLEN,
-#     model=model
-# )
-
-# tokens = torch.cat([inp for inp, _ in train_tokens], dim=0)
-tokens = torch.zeros((NSAMPLES, SEQLEN), dtype=torch.int64)
-print("tokens.shape", tokens.shape)
-
-is_last_layer = lambda x: x == (model.args.n_layer - 1)
-
-start_time = time.time()
-
-#TODO: Do the same in GPU side
-with torch.no_grad():
+@torch.no_grad()
+def quantize_gptq_custom(model, tokens):
+    nsamples = tokens.shape[0]
     seq_mode = len(tokens) > 1
+    is_last_layer = lambda x: x == (model.args.n_layer - 1)
+
     inps = model.w['emb.weight'][tokens if seq_mode else tokens[0]]
     outs = torch.zeros_like(inps)
-
     quantizers = {}
-
+
     for layer_id in range(model.args.n_layer):
+
+        print(f"Quantizing layer {layer_id} ...")
 
         model.alloc_gptq(layer_id)
 
-        for j in range(NSAMPLES):
+        for i in range(nsamples):
+            #TODO: Are the outs values normal? (they look almost all the same)
             if not is_last_layer(layer_id):
-                outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode)
+                outs[i] = model.forward_block(inps[i], state=None, i=layer_id, seq_mode=seq_mode)
             else:
-                _ = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode)
-
+                _ = model.forward_block(inps[i], state=None, i=layer_id, seq_mode=seq_mode)
+
+        for gptq_layer in model.gptq.values():
+            gptq_layer.deactivate_add_batch_call = True
+
+        tmp = model.w["blocks.0.att.key.weight"]
+
         model.fasterquant(layer_id, quantizers)
 
-        for j in range(NSAMPLES):
+        for i in range(nsamples):
             if not is_last_layer(layer_id):
-                outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode)
+                outs[i] = model.forward_block(inps[i], state=None, i=layer_id, seq_mode=seq_mode)
             else:
-                _ = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode)
-
+                _ = model.forward_block(inps[i], state=None, i=layer_id, seq_mode=seq_mode)
+
+        # Assign the quantized weights to the model
+        for key in model.gptq.keys():
+            model.w[key].copy_(model.gptq[key].weight)
+
         model.free_gptq()
 
         # We need to pass the outputs of block i as input of block i+1 (except for last block)
         if not is_last_layer(layer_id):
             inps, outs = outs, inps
 
-end_time = time.time()
+    return quantizers
+
+if __name__ == "__main__":
+
+    model = GPTQ_RWKV("./RWKV-4-Pile-169M-20220807-8023.pth", strategy='cpu fp32')
+
+    NSAMPLES = 2
+    HIDDEN_SIZE = model.args.n_embd
+    SEQLEN = 1024  # cf https://huggingface.co/BlinkDL/rwkv-4-pile-169m
+
+    train_tokens, test_tokens = get_loaders(
+        dataset_name="wikitext2",
+        nsamples=NSAMPLES,
+        seed=42,
+        seqlen=SEQLEN,
+        model=model
+    )
+
+    tokens = torch.cat([inp for inp, _ in train_tokens], dim=0)
+    tokens = torch.zeros((NSAMPLES, SEQLEN), dtype=torch.int64)
+    print("tokens.shape", tokens.shape)
 
-print(f"Done in {end_time - start_time:.2f} seconds")
+    import pdb; pdb.set_trace()
+    # quantizers = quantize_gptq_custom(model, tokens)
 
-# TODO: Do something with quantizers dictionary
-# TODO: pack3 save model
+    # model_pack_custom(model, quantizers, WBITS, GROUPSIZE)
+    # torch.save(model.state_dict(), "model_quantized_custom.pt")
+    # print("Done Custom GPTQ")
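
`get_loaders` and `model_pack_custom` are not defined in the hunks shown here; until the packing step lands, a minimal way to exercise the new path and keep the collected statistics around might be (a sketch with a hypothetical output file name):

# Hypothetical usage once real calibration tokens are available:
quantizers = quantize_gptq_custom(model, tokens)
torch.save({"state_dict": model.state_dict(), "quantizers": quantizers},
           "rwkv_gptq_fake_quantized.pt")   # fake-quantized weights + GPTQ metadata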