from model import Transformer

-default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
-
##### Quantization Primitives ######

def dynamically_quantize_per_channel(x, quant_min, quant_max, target_dtype):
@@ -328,8 +326,8 @@ def create_quantized_state_dict(self):
        for fqn, mod in self.mod.named_modules():
            if isinstance(mod, torch.nn.Linear):
                int8_weight, scales, _ = dynamically_quantize_per_channel(mod.weight.float(), -128, 127, torch.int8)
-                cur_state_dict[f"{fqn}.weight"] = int8_weight.to('cpu')
-                cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype).to('cpu')
+                cur_state_dict[f"{fqn}.weight"] = int8_weight
+                cur_state_dict[f"{fqn}.scales"] = scales.to(mod.weight.dtype)

        return cur_state_dict

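For reference, a minimal sketch of how the int8 handler above exercises the quantization primitive once the explicit .to('cpu') moves are dropped; the Linear layer here is illustrative, and the import assumes the file is importable as `quantize`:

    import torch
    from quantize import dynamically_quantize_per_channel  # assumed module name

    lin = torch.nn.Linear(4096, 4096, bias=False)
    # Symmetric per-channel range [-128, 127] packed into int8, as in the hunk above.
    int8_weight, scales, _ = dynamically_quantize_per_channel(lin.weight.float(), -128, 127, torch.int8)
    # The handler stores these as f"{fqn}.weight" and, cast back to the weight dtype, f"{fqn}.scales".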
@@ -365,9 +363,6 @@ def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
    weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
    return weight_int4pack, scales_and_zeros

-def _calc_padded_size(k, groupsize=1, innner_k_tiles=1):
-    from model import find_multiple
-    return find_multiple(k, 1024)

def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
    origin_x_size = x.size()
@@ -381,29 +376,39 @@ def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
    return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0

-def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_cuda):
+def replace_linear_int4(module, groupsize, inner_k_tiles, padding):
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
-            if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed:
+            if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
+                setattr(module, name, WeightOnlyInt4Linear(
+                    child.in_features, child.out_features, bias=False,
+                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=False,
+                ))
+            elif padding:
                setattr(module, name, WeightOnlyInt4Linear(
                    child.in_features, child.out_features, bias=False,
-                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, use_cuda=use_cuda
+                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=True,
                ))
        else:
-            replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, use_cuda)
+            replace_linear_int4(child, groupsize, inner_k_tiles, padding)


class WeightOnlyInt4QuantHandler:
-    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding_allowed=True):
+    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
        self.mod = mod
        self.groupsize = groupsize
        self.inner_k_tiles = inner_k_tiles
-        self.padding_allowed = padding_allowed
+        self.padding = padding
        assert groupsize in [32, 64, 128, 256]
        assert inner_k_tiles in [2, 4, 8]

    @torch.no_grad()
-    def create_quantized_state_dict(self):
+    def create_quantized_state_dict(self, use_cuda=True):
+        if use_cuda:
+            device = "cuda"
+        else:
+            device = "cpu"
+
        cur_state_dict = self.mod.state_dict()
        for fqn, mod in self.mod.named_modules():
            if isinstance(mod, torch.nn.Linear):
@@ -415,7 +420,7 @@ def create_quantized_state_dict(self):

                weight = mod.weight.data
                if not _check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles):
-                    if self.padding_allowed:
+                    if self.padding:
                        from model import find_multiple
                        import torch.nn.functional as F
                        print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0")
@@ -426,15 +431,15 @@ def create_quantized_state_dict(self):
                            "and that groupsize and inner_k_tiles*16 evenly divide into it")
                        continue
                weight_int4pack, scales_and_zeros = prepare_int4_weight_and_scales_and_zeros(
-                    weight.to(torch.bfloat16), self.groupsize, self.inner_k_tiles
+                    weight.to(torch.bfloat16).to(device=device), self.groupsize, self.inner_k_tiles
                )
                cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to('cpu')
                cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to('cpu')

        return cur_state_dict

-    def convert_for_runtime(self, use_cuda):
-        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding_allowed, use_cuda)
+    def convert_for_runtime(self):
+        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
        return self.mod

class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler):
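Taken together, the hunks above change the int4 handler's calling convention: weight packing can be placed on CUDA via create_quantized_state_dict(use_cuda=...), while convert_for_runtime() no longer takes a device flag and instead relies on the stored padding setting. A hedged usage sketch, assuming `model` is an already-loaded Transformer:

    handler = WeightOnlyInt4QuantHandler(model, groupsize=128, inner_k_tiles=8, padding=True)
    quantized_state_dict = handler.create_quantized_state_dict(use_cuda=True)  # packs on "cuda", results land on CPU
    model = handler.convert_for_runtime()  # swaps eligible nn.Linear modules for WeightOnlyInt4Linear
    # elsewhere the quantized state dict would be saved to disk and loaded back into the converted model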
@@ -458,10 +463,7 @@ def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
        # we need to do the padding here, both for q and the qparams if necessary
        def make_names_and_values_dict_func(q, qparams):
            k = q.shape[1]
-            if not _check_linear_int4_k(k, groupsize, inner_k_tiles):
-                new_k = find_multiple(k, 1024)
-            else:
-                new_k = k
+            new_k = find_multiple(k, 1024)
            # how much we need to pad the weight
            delta_k = new_k - q.shape[1]
            final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles)
@@ -474,8 +476,8 @@ def make_names_and_values_dict_func(q, qparams):
        super().__init__()


-    def convert_for_runtime(self, use_cuda):
-        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding, use_cuda)
+    def convert_for_runtime(self):
+        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding)
        return self.mod

class WeightOnlyInt4Linear(torch.nn.Module):
@@ -486,11 +488,11 @@ class WeightOnlyInt4Linear(torch.nn.Module):

    def __init__(
            self, in_features: int, out_features: int,
-            bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, use_cuda=True,
+            bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, padding: bool = True,
    ) -> None:
        super().__init__()
-        self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
-        if self.padding:
+        self.padding = padding
+        if padding:
            from model import find_multiple
            self.origin_in_features = in_features
            in_features = find_multiple(in_features, 1024)
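The WeightOnlyInt4Linear constructor now receives padding explicitly instead of re-deriving it from _check_linear_int4_k; when padding is requested, in_features is rounded up to a multiple of 1024 while origin_in_features keeps the caller-facing size. A small illustrative sketch of that rounding (the feature size is made up, and find_multiple is assumed to round up to the next multiple of its second argument):

    from model import find_multiple

    in_features = 4000                         # fails _check_linear_int4_k for e.g. groupsize=128
    padded = find_multiple(in_features, 1024)  # 4096, the next multiple of 1024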
@@ -536,9 +538,9 @@ def quantize(
    percdamp: float = .01,
    blocksize: int = 128,
    label: str = '',
-    device: str = default_device,
) -> None:
    assert checkpoint_path.is_file(), checkpoint_path
+
    device = 'cpu'
    precision = torch.bfloat16
@@ -549,8 +551,6 @@ def quantize(
        model = Transformer.from_name(checkpoint_path.parent.name)

    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
-    if "model" in checkpoint and "stories" in str(checkpoint_path):
-        checkpoint = checkpoint["model"]
    model.load_state_dict(checkpoint, assign=True)
    model = model.to(dtype=precision, device=device)
@@ -565,13 +565,12 @@ def quantize(

    elif mode == 'int4':
        print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization")
-        print(f"Prepacking model weights in {device} optimal layout")
        quant_handler = WeightOnlyInt4QuantHandler(model, groupsize)
        quantized_state_dict = quant_handler.create_quantized_state_dict()

        dir_name = checkpoint_path.parent
        base_name = checkpoint_path.name
-        new_base_name = base_name.replace('.pth', f"{label}int4.g{groupsize}.{device}.pth")
+        new_base_name = base_name.replace('.pth', f"{label}int4.g{groupsize}.pth")

    elif mode == 'int4-gptq':
        print("Quantizing model weights for int4 weight-only affine per-channel groupwise quantization using GPTQ...")
@@ -594,7 +593,7 @@ def quantize(

        dir_name = checkpoint_path.parent
        base_name = checkpoint_path.name
-        new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.{device}.pth")
+        new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.pth")

    else:
        raise ValueError(f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]")
@@ -618,7 +617,6 @@ def quantize(
    parser.add_argument('--percdamp', type=float, default=.01, help='gptq percentage dampening')
    parser.add_argument('--blocksize', type=int, default=128, help='blocksize for gptq')
    parser.add_argument('--label', type=str, default='_', help='label to add to output filename')
-    parser.add_argument('--device', type=str, default=default_device, help='device to use')

    args = parser.parse_args()
-    quantize(args.checkpoint_path, args.mode, args.groupsize, args.calibration_tasks, args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, args.percdamp, args.blocksize, args.label, args.device)
+    quantize(args.checkpoint_path, args.mode, args.groupsize, args.calibration_tasks, args.calibration_limit, args.calibration_seq_length, args.pad_calibration_inputs, args.percdamp, args.blocksize, args.label)
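With the --device flag and the trailing device argument removed, the script always quantizes on the hard-coded 'cpu' device (apart from the optional CUDA prepacking above). A hedged sketch of the equivalent programmatic call; the checkpoint path is illustrative, the keyword names for mode and groupsize are assumed to match the argparse destinations, and the remaining calibration parameters are assumed to keep their defaults:

    from pathlib import Path

    quantize(
        Path("checkpoints/my-model/model.pth"),  # illustrative path
        mode='int4',                             # assumed keyword name
        groupsize=128,                           # assumed keyword name
        label='_',
    )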