 import torch.nn.functional as F
 import torch.nn as nn
 import time
-import gc
 import math
 import re

@@ -132,27 +131,21 @@ def fasterquant(self, blocksize=128, percdamp=.01, groupsize=-1, actorder=False)
     ### begin GPTQ_RWKV
     def __init__(self, model, strategy):
         super().__init__(model, strategy)
-        #TODO: add assert to only quantize in CPU FP32 mode
+        for i in range(self.args.n_layer):
+            assert self.strategy[i].device == "cpu"

     def _fill_subset(self, layer_id):
         # Keep only layer within block layer_id
-        dd = self.strategy[layer_id]
-        dev = dd.device
-
-        for name in self.w.keys():
-            if re.match(f'^blocks\.{layer_id}\..*\.weight$', name):
-                tensor = self.w[name]
-
-                #TODO: Skip 1D tensors for now
-                if len(tensor.shape) == 1:
-                    continue
-
-                print(f"{name} = {self.w[name].shape}")
-
-                if re.match(f'^blocks\.{layer_id}\.(?:att|ffn)\.(?:key|value|output|receptance)\.weight$', name):
-                    tensor = tensor.to(device=dev, non_blocking=True)
-
-                self.subset[name] = tensor
+        is_weight = re.compile(f'^blocks\.{layer_id}\..*\.weight$')
+        for name in self.w.keys():
+            if is_weight.match(name):
+                if len(self.w[name].shape) == 1: continue  #TODO: Skip 1D tensors for now
+                self.subset[name] = self.w[name]
+
+        is_last_layer = (layer_id == self.args.n_layer - 1)
+        if is_last_layer:
+            self.subset["head.weight"] = self.w["head.weight"]
+

     def alloc_gptq(self, layer_id):
         self.subset = {}
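For readers skimming the new `_fill_subset`: the regex keeps every 2D `blocks.{layer_id}.*.weight` tensor (1D tensors such as LayerNorm weights are still skipped) and, on the last block only, also pulls in `head.weight` so the output head gets calibrated. A standalone toy illustration of that selection rule (not part of the commit, dummy shapes only):

```python
import re
import torch

w = {
    "blocks.0.att.key.weight": torch.empty(4, 4),
    "blocks.0.ln1.weight": torch.empty(4),        # 1D -> skipped
    "blocks.1.ffn.value.weight": torch.empty(4, 4),
    "head.weight": torch.empty(4, 4),
}
n_layer = 2

def fill_subset(layer_id):
    subset = {}
    is_weight = re.compile(rf'^blocks\.{layer_id}\..*\.weight$')
    for name, tensor in w.items():
        if is_weight.match(name) and len(tensor.shape) > 1:
            subset[name] = tensor
    if layer_id == n_layer - 1:                   # last block also owns the head
        subset["head.weight"] = w["head.weight"]
    return subset

print(list(fill_subset(0)))  # ['blocks.0.att.key.weight']
print(list(fill_subset(1)))  # ['blocks.1.ffn.value.weight', 'head.weight']
```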
@@ -178,7 +171,6 @@ def fasterquant(self, layer_id, quantizers):
             self.gptq[name].fasterquant(percdamp=0.01, groupsize=-1, actorder=False)
             # self.gptq[name].fastquant(percdamp=args.percdamp, groupsize=args.groupsize, actorder=args.act_order)
             quantizers[name] = self.gptq[name].quantizer
-            # TODO: may be free gptq here to save memory

     ### end GPTQ_RWKV

@@ -272,7 +264,7 @@ def ffn_seq(self, x, sx, ln_w, ln_b, k_mix, r_mix, kw, vw, rw, kmx, krx, kmy, kr
         vw.add_batch(vx)
         return x + out, xx[-1,:]

-    def forward_block(self, x, state, i, seq_mode, is_last_layer, full_output=False):
+    def forward_block(self, x, state, i, seq_mode, full_output=False):
         with torch.no_grad():
             args = self.args

@@ -312,6 +304,12 @@ def forward_block(self, x, state, i, seq_mode, is_last_layer, full_output=False)
             rw = self.gptq[f'{att}receptance.weight']
             ow = self.gptq[f'{att}output.weight']

+            if dd.stream:
+                kw = kw.to(device=dev, non_blocking=True)
+                vw = vw.to(device=dev, non_blocking=True)
+                rw = rw.to(device=dev, non_blocking=True)
+                ow = ow.to(device=dev, non_blocking=True)
+
             kmx = self.w[f'{att}key.weight_mx'] if wtype == torch.uint8 else x
             krx = self.w[f'{att}key.weight_rx'] if wtype == torch.uint8 else x
             kmy = self.w[f'{att}key.weight_my'] if wtype == torch.uint8 else x
@@ -341,6 +339,7 @@ def forward_block(self, x, state, i, seq_mode, is_last_layer, full_output=False)
                 omx=omx, orx=orx, omy=omy, ory=ory,
             )

+            # Deactivate add_batch() after quantization is applied
             kw.deactivate_add_batch_call = True
             vw.deactivate_add_batch_call = True
             rw.deactivate_add_batch_call = True
@@ -352,6 +351,7 @@ def forward_block(self, x, state, i, seq_mode, is_last_layer, full_output=False)
             kw = self.gptq[f'{ffn}key.weight']
             vw = self.gptq[f'{ffn}value.weight']
             rw = self.gptq[f'{ffn}receptance.weight']
+
             if dd.stream:
                 kw = kw.to(device=dev, non_blocking=True)
                 vw = vw.to(device=dev, non_blocking=True)
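Context for the `add_batch()` / `deactivate_add_batch_call` pattern used throughout `forward_block`: during the calibration forward passes each wrapped weight records second-order statistics of its inputs (GPTQ works from a Hessian approximation H ≈ 2·Σ x xᵀ), and the flag freezes those statistics once quantization has been applied. The class below is only a toy stand-in to show that bookkeeping, not the actual GPTQ wrapper used in this commit:

```python
import torch

class GPTQStats:
    """Toy stand-in: accumulate H = 2 * sum(x x^T) over calibration inputs."""
    def __init__(self, in_features):
        self.H = torch.zeros(in_features, in_features)
        self.nsamples = 0
        self.deactivate_add_batch_call = False

    def add_batch(self, inp):
        if self.deactivate_add_batch_call:
            return                          # post-quantization pass: stats frozen
        if inp.dim() == 1:
            inp = inp.unsqueeze(0)          # (in_features,) -> (1, in_features)
        self.nsamples += inp.shape[0]
        self.H += 2.0 * inp.t() @ inp       # running sum of 2 * x x^T

stats = GPTQStats(in_features=8)
stats.add_batch(torch.randn(16, 8))         # calibration activations
stats.deactivate_add_batch_call = True
stats.add_batch(torch.randn(16, 8))         # ignored, as in the second forward pass
```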
@@ -391,43 +391,46 @@ def forward_block(self, x, state, i, seq_mode, is_last_layer, full_output=False)
             if (i+1) % self.RESCALE_LAYER == 0:
                 x = x / 2

-            if is_last_layer:
+            is_last_layer = i == (args.n_layer - 1)
+
+            if is_last_layer:
                 dd = self.strategy[args.n_layer]
                 x = x[-1,:] if (seq_mode and (not full_output)) else x
                 x = x.to(dtype=dd.atype, device=dd.device)

-                #TODO: Add GPTQ support for head & ln_out
+                #TODO: ln_out.weight is 1D tensor
                 x = F.layer_norm(x, (args.n_embd,), weight=self.w['ln_out.weight'], bias=self.w['ln_out.bias'])
+
                 if self.w['head.weight'].dtype != torch.uint8:
-                    x = x @ self.w['head.weight']
-                else:
-                    if seq_mode and full_output:
-                        x = self.mm8_seq(x, self.w['head.weight'], self.w['head.weight_mx'], self.w['head.weight_rx'], self.w['head.weight_my'], self.w['head.weight_ry'])
-                    else:
-                        x = self.mm8_one(x, self.w['head.weight'], self.w['head.weight_mx'], self.w['head.weight_rx'], self.w['head.weight_my'], self.w['head.weight_ry'])
+                    # Collect GPTQ stats on the pre-projection activations (as the
+                    # other add_batch() calls do), then apply the head projection.
+                    self.gptq['head.weight'].add_batch(x)
+                    self.gptq['head.weight'].deactivate_add_batch_call = True
+                    x = x @ self.gptq['head.weight'].weight

             return x.float()

     ### end RWKV

+model = GPTQ_RWKV("./RWKV-4-Pile-169M-20220807-8023.pth", strategy='cpu fp32')
+
 NSAMPLES = 2
-HIDDEN_SIZE = 768
-SEQLEN = 2048 # TODO: this is chosen by the model
+HIDDEN_SIZE = model.args.n_embd
+SEQLEN = 1024 # cf https://huggingface.co/BlinkDL/rwkv-4-pile-169m

 # train_tokens, test_tokens = get_loaders(
 #     dataset_name="wikitext2",
 #     nsamples=NSAMPLES,
 #     seed=42,
 #     seqlen=SEQLEN,
-#     model=None
+#     model=model
 # )

 # tokens = torch.cat([inp for inp, _ in train_tokens], dim=0)
 tokens = torch.zeros((NSAMPLES, SEQLEN), dtype=torch.int64)
 print("tokens.shape", tokens.shape)

-model = GPTQ_RWKV("./RWKV-4-Pile-169M-20220807-8023.pth", strategy='cpu fp32')
-is_last_layer = [False] * (model.args.n_layer - 1) + [True]
+is_last_layer = lambda x: x == (model.args.n_layer - 1)
+
+start_time = time.time()

 #TODO: Do the same in GPU side
 with torch.no_grad():
@@ -442,23 +445,28 @@ def forward_block(self, x, state, i, seq_mode, is_last_layer, full_output=False)
         model.alloc_gptq(layer_id)

         for j in range(NSAMPLES):
-            if not is_last_layer[layer_id]:
-                outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode, is_last_layer=is_last_layer[layer_id])
+            if not is_last_layer(layer_id):
+                outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode)
             else:
-                _ = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode, is_last_layer=is_last_layer[layer_id])
+                _ = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode)

         model.fasterquant(layer_id, quantizers)

         for j in range(NSAMPLES):
-            if not is_last_layer[layer_id]:
-                outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode, is_last_layer=is_last_layer[layer_id])
+            if not is_last_layer(layer_id):
+                outs[j] = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode)
             else:
-                _ = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode, is_last_layer=is_last_layer[layer_id])
+                _ = model.forward_block(inps[j], state=None, i=layer_id, seq_mode=seq_mode)
+
         model.free_gptq()

-        if not is_last_layer[layer_id]:
-            # We need to pass the outputs of block i as input of block i+1 (except for last block)
+        # We need to pass the outputs of block i as input of block i+1 (except for the last block)
+        if not is_last_layer(layer_id):
             inps, outs = outs, inps

-# TODO: create a function that check if all weights were properly quantized
-print("Done")
+end_time = time.time()
+
+print(f"Done in {end_time - start_time:.2f} seconds")
+
+# TODO: Do something with quantizers dictionary
+# TODO: pack3 save model
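One possible follow-up for the two trailing TODOs, sketched under assumptions: it presumes the quantizer objects stored in `quantizers` expose `scale` and `zero` tensors (as in the reference GPTQ implementation); the helper name and output path are made up for illustration, and real 3-bit packing would still be a separate step:

```python
import torch

def save_quantizer_params(quantizers, path="rwkv-169m-gptq-quantizers.pt"):
    # Persist the per-weight quantization grids so a later pack/save step can
    # rebuild the quantized checkpoint without re-running calibration.
    payload = {
        name: {"scale": q.scale.cpu(), "zero": q.zero.cpu()}
        for name, q in quantizers.items()
    }
    torch.save(payload, path)

# save_quantizer_params(quantizers)
```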