 from keras.src import quantizers
 from keras.src import regularizers
 from keras.src.api_export import keras_export
-from keras.src.dtype_policies.dtype_policy import GPTQDTypePolicy
-from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap
 from keras.src.layers.input_spec import InputSpec
 from keras.src.layers.layer import Layer
-from keras.src.quantizers.gptq_config import GPTQConfig
 from keras.src.quantizers.quantizers import dequantize_with_sz_map
 
 
@@ -143,22 +140,47 @@ def build(self, input_shape):
 
     @property
     def kernel(self):
+        from keras.src.quantizers import gptq_core
+
         if not self.built:
             raise AttributeError(
                 "You must build the layer before accessing `kernel`."
             )
-        if (
-            getattr(self, "is_gptq_calibrated", False)
-            and self.quantization_mode == "gptq"
-        ):
-            return self.quantized_kernel
-        kernel = self._kernel
-        if self.quantization_mode == "int4":
-            kernel = quantizers.unpack_int4(kernel, self._orig_input_dim)
+
+        mode = self.quantization_mode
+        is_gptq = mode == "gptq"
+        is_int4 = mode == "int4"
+        calibrated = bool(getattr(self, "is_gptq_calibrated", False))
+        gptq_bits = (
+            gptq_core.get_weight_bits_for_layer(self, None) if is_gptq else None
+        )
+
+        # Decide the source tensor first (packed vs already-quantized vs plain
+        # kernel).
+        if is_gptq and calibrated and gptq_bits != 4:
+            # Calibrated GPTQ, not 4-bit: no unpacking needed.
+            kernel = self.quantized_kernel
+        else:
+            # Start with the stored kernel.
+            kernel = getattr(self, "_kernel", None)
+
+            # Handle int4 unpacking cases in one place.
+            if is_int4:
+                kernel = quantizers.unpack_int4(kernel, self._orig_input_dim)
+            elif is_gptq and calibrated and gptq_bits == 4:
+                kernel = quantizers.unpack_int4(
+                    self.quantized_kernel,
+                    orig_len=self.units,
+                    axis=0,
+                    dtype="uint8",
+                )
+
+        # Apply LoRA once at the end.
         if self.lora_enabled:
-            return kernel + (self.lora_alpha / self.lora_rank) * ops.matmul(
+            kernel = kernel + (self.lora_alpha / self.lora_rank) * ops.matmul(
                 self.lora_kernel_a, self.lora_kernel_b
             )
+
         return kernel
 
     def call(self, inputs, training=None):
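Reviewer note: the LoRA term at the end of the property merges a low-rank update into whichever kernel branch was selected. A minimal sketch of that update, assuming the usual Dense LoRA layout where `lora_kernel_a` has shape `(input_dim, rank)` and `lora_kernel_b` has shape `(rank, units)` (illustrative sizes, not taken from this PR):

    import numpy as np

    input_dim, units, rank = 8, 16, 4
    lora_alpha = 2.0

    kernel = np.zeros((input_dim, units), dtype="float32")
    lora_kernel_a = np.random.standard_normal((input_dim, rank)).astype("float32")
    lora_kernel_b = np.random.standard_normal((rank, units)).astype("float32")

    # Same update the property applies once, after the quantization branches.
    merged = kernel + (lora_alpha / rank) * (lora_kernel_a @ lora_kernel_b)
    assert merged.shape == (input_dim, units)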
@@ -414,23 +436,33 @@ def _int8_build(self, kernel_shape):
         )
 
     def _gptq_build(self, kernel_shape, config):
+        from keras.src.quantizers import gptq_core
+
         # Ensures the forward pass uses the original high-precision kernel
         # until calibration has been performed.
         self.is_gptq_calibrated = False
         self.kernel_shape = kernel_shape
+
+        weight_bits = gptq_core.get_weight_bits_for_layer(self, config)
+        # For 4-bit weights, we pack two values per byte.
+        units = (
+            (kernel_shape[1] + 1) // 2 if weight_bits == 4 else kernel_shape[1]
+        )
+
         self.quantized_kernel = self.add_weight(
             name="kernel",
-            shape=(kernel_shape[1], kernel_shape[0]),
+            shape=(units, kernel_shape[0]),
             initializer="zeros",
             dtype="uint8",
             trainable=False,
         )
 
-        group_size = self._get_gptq_group_size(config)
-        if group_size == -1:
-            n_groups = 1
-        else:
-            n_groups = math.ceil(self.kernel_shape[0] / group_size)
+        group_size = gptq_core.get_group_size_for_layer(self, config)
+        n_groups = (
+            1
+            if group_size == -1
+            else math.ceil(self.kernel_shape[0] / group_size)
+        )
         self.kernel_scale = self.add_weight(
             name="kernel_scale",
             shape=(self.units, n_groups),
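Reviewer note: a small arithmetic sketch of the shape bookkeeping `_gptq_build` now does, with illustrative sizes that are not taken from this PR:

    import math

    input_dim, units = 512, 1024      # kernel_shape == (input_dim, units)
    weight_bits, group_size = 4, 128  # example GPTQ settings

    # Two 4-bit values are packed into one uint8 byte, so the stored kernel
    # keeps only half as many rows along the `units` axis (rounded up).
    packed_units = (units + 1) // 2 if weight_bits == 4 else units  # -> 512

    # One scale/zero pair per group of input features.
    n_groups = 1 if group_size == -1 else math.ceil(input_dim / group_size)  # -> 4

    # quantized_kernel: (packed_units, input_dim) == (512, 512), dtype uint8
    # kernel_scale:     (units, n_groups)         == (1024, 4)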
@@ -453,18 +485,31 @@ def _gptq_build(self, kernel_shape, config):
         )
 
     def _gptq_call(self, inputs, training=False):
+        from keras.src.quantizers import gptq_core
+
         if not self.is_gptq_calibrated:
             W = self._kernel
         else:
+            should_unpack = (
+                gptq_core.get_weight_bits_for_layer(self, config=None) == 4
+            )
             W = (
-                ops.transpose(
-                    dequantize_with_sz_map(
-                        self.quantized_kernel,
-                        self.kernel_scale,
-                        self.kernel_zero,
-                        self.g_idx,
-                    )
-                ),
+                quantizers.unpack_int4(
+                    self.quantized_kernel,
+                    orig_len=self.units,
+                    axis=0,
+                    dtype="uint8",
+                )
+                if should_unpack
+                else self.quantized_kernel
+            )
+            W = ops.transpose(
+                dequantize_with_sz_map(
+                    W,
+                    self.kernel_scale,
+                    self.kernel_zero,
+                    self.g_idx,
+                )
             )
 
         y = ops.matmul(inputs, W)
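Reviewer note: an illustrative shape trace of the calibrated 4-bit forward path above, continuing the example sizes from the previous note (the batch size is hypothetical):

    # quantized_kernel                         : (packed_units, input_dim) == (512, 512),  uint8
    # unpack_int4(..., orig_len=units, axis=0) : (units, input_dim)        == (1024, 512), uint8
    # dequantize_with_sz_map(W, scale, zero, g_idx) keeps that shape       == (1024, 512), float
    # ops.transpose(...)                       : (input_dim, units)        == (512, 1024)
    # ops.matmul(inputs, W): (batch, input_dim) @ (input_dim, units) -> (batch, units)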
@@ -875,43 +920,3 @@ def _get_kernel_with_merged_lora(self):
         else:
             kernel_value = requantized_kernel
         return kernel_value, kernel_scale
-
-    def _get_gptq_group_size(self, config):
-        """Determine the group size for GPTQ quantization.
-
-        The group size can be specified either through the `config` argument
-        or through the `dtype_policy` if it is of type `GPTQDTypePolicy`.
-
-        The config argument is usually available when quantizing the layer
-        via the `quantize` method. If the layer was deserialized from a
-        saved model, the group size should be specified in the `dtype_policy`.
-
-        Args:
-            config: An optional configuration object that may contain the
-                `group_size` attribute.
-        Returns:
-            int. The determined group size for GPTQ quantization.
-        Raises:
-            ValueError: If the group size is not specified in either the
-                `config` or the `dtype_policy`.
-        """
-        if config and isinstance(config, GPTQConfig):
-            return config.group_size
-        elif isinstance(self.dtype_policy, GPTQDTypePolicy):
-            return self.dtype_policy.group_size
-        elif isinstance(self.dtype_policy, DTypePolicyMap):
-            policy = self.dtype_policy[self.path]
-            if not isinstance(policy, GPTQDTypePolicy):
-                # This should never happen based on how we set the
-                # quantization mode, but we check just in case.
-                raise ValueError(
-                    "Expected a `dtype_policy` of type `GPTQDTypePolicy`."
-                    f"Got: {type(policy)}"
-                )
-            return policy.group_size
-        else:
-            raise ValueError(
-                "For GPTQ quantization, the group_size must be specified"
-                "either through a `dtype_policy` of type "
-                "`GPTQDTypePolicy` or the `config` argument."
-            )
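Reviewer note: the deleted `_get_gptq_group_size` is superseded by `gptq_core.get_group_size_for_layer(layer, config)`, which `_gptq_build` now calls. That helper's body is not shown in this diff; the sketch below is only a guess at its resolution order (explicit `GPTQConfig` first, then the layer's dtype policy), mirroring the removed method, and may differ from the real implementation in `keras/src/quantizers/gptq_core.py`:

    # Hypothetical sketch -- not the PR's actual helper.
    from keras.src.dtype_policies.dtype_policy import GPTQDTypePolicy
    from keras.src.dtype_policies.dtype_policy_map import DTypePolicyMap
    from keras.src.quantizers.gptq_config import GPTQConfig


    def get_group_size_for_layer(layer, config):
        """Resolve the GPTQ group size from `config` or the layer's dtype policy."""
        if config is not None and isinstance(config, GPTQConfig):
            return config.group_size
        policy = layer.dtype_policy
        if isinstance(policy, DTypePolicyMap):
            policy = policy[layer.path]
        if isinstance(policy, GPTQDTypePolicy):
            return policy.group_size
        raise ValueError(
            "For GPTQ quantization, `group_size` must be specified either "
            "through a `GPTQDTypePolicy` dtype policy or the `config` argument."
        )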