@@ -365,6 +365,9 @@ def prepare_int4_weight_and_scales_and_zeros(weight_bf16, groupsize, inner_k_tiles):
     weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(weight_int32, inner_k_tiles)
     return weight_int4pack, scales_and_zeros
 
+def _calc_padded_size(k, groupsize=1, innner_k_tiles=1):
+    from model import find_multiple
+    return find_multiple(k, 1024)
 
 def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
     origin_x_size = x.size()
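For orientation, a minimal sketch of the padding-size helper this hunk introduces, assuming `find_multiple` (imported from `model`) rounds its argument up to the next multiple of the given step; the stand-in definition below is an assumption for illustration, not code from this PR:

```python
# Sketch only: a stand-in for model.find_multiple, which rounds n up to the
# nearest multiple of k (returning n unchanged when it already divides evenly).
def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)

# _calc_padded_size(k) then reports the in_features size after padding to a
# multiple of 1024, e.g.:
assert find_multiple(4096, 1024) == 4096   # already aligned, no padding
assert find_multiple(5000, 1024) == 5120   # rounded up to the next multiple
```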
@@ -378,29 +381,24 @@ def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
 def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
     return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0
 
-def replace_linear_int4(module, groupsize, inner_k_tiles, padding, use_cuda):
+def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_cuda):
     for name, child in module.named_children():
         if isinstance(child, nn.Linear):
-            if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles):
+            if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed:
                 setattr(module, name, WeightOnlyInt4Linear(
                     child.in_features, child.out_features, bias=False,
-                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=False, use_cuda=use_cuda
-                ))
-            elif padding:
-                setattr(module, name, WeightOnlyInt4Linear(
-                    child.in_features, child.out_features, bias=False,
-                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, padding=True, use_cuda=use_cuda
+                    groupsize=groupsize, inner_k_tiles=inner_k_tiles, use_cuda=use_cuda
                 ))
         else:
-            replace_linear_int4(child, groupsize, inner_k_tiles, padding, use_cuda)
+            replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, use_cuda)
 
 
 class WeightOnlyInt4QuantHandler:
-    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
+    def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding_allowed=True):
         self.mod = mod
         self.groupsize = groupsize
         self.inner_k_tiles = inner_k_tiles
-        self.padding = padding
+        self.padding_allowed = padding_allowed
         assert groupsize in [32, 64, 128, 256]
         assert inner_k_tiles in [2, 4, 8]
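To make the renamed flag concrete, here is a hedged usage sketch (the `DummyBlock` module and its dimensions are invented for illustration): with `padding_allowed=True` every `nn.Linear` is replaced, and each `WeightOnlyInt4Linear` now decides on its own whether that layer needs padding.

```python
import torch.nn as nn

# Hypothetical toy module; names and sizes are illustrative, not from the PR.
class DummyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4096, 4096, bias=False)  # already kernel-friendly
        self.head = nn.Linear(5000, 256, bias=False)   # would need padding to 5120

block = DummyBlock()
# A single padding_allowed flag gates replacement of all eligible layers;
# the per-layer padding decision moves into WeightOnlyInt4Linear itself.
replace_linear_int4(block, groupsize=128, inner_k_tiles=8,
                    padding_allowed=True, use_cuda=True)
```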
@@ -417,7 +415,7 @@ def create_quantized_state_dict(self):
                 weight = mod.weight.data
                 if not _check_linear_int4_k(in_features, self.groupsize, self.inner_k_tiles):
-                    if self.padding:
+                    if self.padding_allowed:
                         from model import find_multiple
                         import torch.nn.functional as F
                         print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0")
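As a worked illustration of the padding branch above (a sketch with made-up shapes, not the PR's code): `F.pad` extends the last dimension of the bf16 weight so that `in_features` becomes a multiple of 1024 before quantization.

```python
import torch
import torch.nn.functional as F

def find_multiple(n: int, k: int) -> int:
    # stand-in for model.find_multiple (assumption for this sketch)
    return n if n % k == 0 else n + k - (n % k)

weight = torch.randn(256, 5000, dtype=torch.bfloat16)   # (out_features, in_features)
in_features = weight.shape[1]
padded_in_features = find_multiple(in_features, 1024)    # 5120
weight = F.pad(weight, pad=(0, padded_in_features - in_features))  # zero-pad K dim
assert weight.shape == (256, 5120)
```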
@@ -436,7 +434,7 @@ def create_quantized_state_dict(self):
         return cur_state_dict
 
     def convert_for_runtime(self, use_cuda):
-        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding, use_cuda)
+        replace_linear_int4(self.mod, self.groupsize, self.inner_k_tiles, self.padding_allowed, use_cuda)
         return self.mod
 
 class WeightOnlyInt4GPTQQuantHandler(GPTQQuantHandler):
@@ -460,7 +458,10 @@ def __init__(self, mod, groupsize=128, inner_k_tiles=8, padding=True):
         # we need to do the padding here, both for q and the qparams if necessary
         def make_names_and_values_dict_func(q, qparams):
             k = q.shape[1]
-            new_k = find_multiple(k, 1024)
+            if not _check_linear_int4_k(k, groupsize, inner_k_tiles):
+                new_k = find_multiple(k, 1024)
+            else:
+                new_k = k
             # how much we need to pad the weight
             delta_k = new_k - q.shape[1]
             final_q = torch.ops.aten._convert_weight_to_int4pack(F.pad(q, pad=(0, delta_k)), inner_k_tiles)
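A small worked check of the new guard (a sketch with made-up sizes): when `k` already satisfies the kernel's divisibility requirements, `new_k == k`, so `delta_k` is zero and the subsequent `F.pad` is a no-op.

```python
def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=1):
    return k % groupsize == 0 and k % (inner_k_tiles * 16) == 0

groupsize, inner_k_tiles = 128, 8
for k in (4096, 5000):
    if not _check_linear_int4_k(k, groupsize, inner_k_tiles):
        new_k = ((k + 1023) // 1024) * 1024   # same result as find_multiple(k, 1024)
    else:
        new_k = k
    delta_k = new_k - k
    print(k, new_k, delta_k)   # 4096 -> delta_k 0 (no padding); 5000 -> delta_k 120
```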
@@ -485,11 +486,11 @@ class WeightOnlyInt4Linear(torch.nn.Module):
 
     def __init__(
         self, in_features: int, out_features: int,
-        bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, padding: bool = True, use_cuda=True,
+        bias=True, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, use_cuda=True,
     ) -> None:
         super().__init__()
-        self.padding = padding
-        if padding:
+        self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
+        if self.padding:
             from model import find_multiple
             self.origin_in_features = in_features
             in_features = find_multiple(in_features, 1024)
@@ -502,16 +503,10 @@ def __init__(
         assert out_features % 8 == 0, "require out_features % 8 == 0"
         assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0"
-        if use_cuda:
-            self.register_buffer(
-                "weight",
-                torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
-            )
-        else:
-            self.register_buffer(
-                "weight",
-                torch.empty((out_features, in_features // 2), dtype=torch.uint8)
-            )
+        self.register_buffer(
+            "weight",
+            torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32)
+        )
         self.register_buffer(
             "scales_and_zeros",
             torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16)
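For reference, a quick sketch of the buffer shapes this now allocates unconditionally (the example dimensions are made up; the layout matches the int4-packed format produced by `_convert_weight_to_int4pack` above):

```python
out_features, in_features = 4096, 4096
groupsize, inner_k_tiles = 128, 8

# int4-packed weight buffer consumed by the fused int4 matmul kernel
weight_shape = (out_features // 8,
                in_features // (inner_k_tiles * 16),
                32,
                inner_k_tiles // 2)          # (512, 32, 32, 4), dtype=torch.int32

# per-group scale and zero-point pairs
scales_and_zeros_shape = (in_features // groupsize, out_features, 2)  # (32, 4096, 2), bfloat16
```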
@@ -544,7 +539,7 @@ def quantize(
     device: str = default_device,
 ) -> None:
     assert checkpoint_path.is_file(), checkpoint_path
-
+    device = 'cpu'
     precision = torch.bfloat16
 
     print("Loading model ...")
@@ -554,6 +549,8 @@ def quantize(
     model = Transformer.from_name(checkpoint_path.parent.name)
 
     checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
+    if "model" in checkpoint and "stories" in str(checkpoint_path):
+        checkpoint = checkpoint["model"]
     model.load_state_dict(checkpoint, assign=True)
     model = model.to(dtype=precision, device=device)
@@ -597,7 +594,7 @@ def quantize(
         dir_name = checkpoint_path.parent
         base_name = checkpoint_path.name
-        new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.pth")
+        new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.{device}.pth")
     else:
         raise ValueError(f"Invalid quantization mode {mode} needs to be one of [int8, int4, int4-gpptq]")
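To see where the new device suffix lands in the output name, a tiny sketch (the checkpoint name, empty label, and groupsize are assumed values, not from the PR):

```python
base_name = "model.pth"
label, groupsize, device = "", 32, "cpu"
new_base_name = base_name.replace('.pth', f"{label}int4-gptq.g{groupsize}.{device}.pth")
print(new_base_name)   # modelint4-gptq.g32.cpu.pth (with an empty label)
```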