 from keras.src import activations
 from keras.src import constraints
-from keras.src import dtype_policies
 from keras.src import initializers
 from keras.src import ops
 from keras.src import quantizers
@@ -110,9 +109,10 @@ def build(self, input_shape):
         kernel_shape = (input_shape[-1], self.units)
         if self.quantization_mode:
             self.quantized_build(kernel_shape, mode=self.quantization_mode)
-        if self.quantization_mode != "int8":
-            # If the layer is quantized to int8, `self._kernel` will be added
-            # in `self._int8_build`. Therefore, we skip it here.
+        if self.quantization_mode not in ("int8", "int4"):
+            # If the layer is quantized to int8 or int4, `self._kernel` will be
+            # added in `self._int8_build` or `_int4_build`. Therefore, we skip
+            # it here.
             self._kernel = self.add_weight(
                 name="kernel",
                 shape=kernel_shape,
@@ -182,9 +182,22 @@ def enable_lora(
                 "lora is already enabled. This can only be done once per layer."
             )
         self._tracker.unlock()
+        # Determine the correct input dimension for the LoRA A matrix. When
+        # the layer has been int4-quantized, `self._kernel` stores a *packed*
+        # representation whose first dimension is `ceil(input_dim/2)`. We
+        # saved the true, *unpacked* input dimension in `self._orig_input_dim`
+        # during quantization. Use it if available; otherwise fall back to the
+        # first dimension of `self.kernel`.
+        if self.quantization_mode == "int4" and hasattr(
+            self, "_orig_input_dim"
+        ):
+            input_dim_for_lora = self._orig_input_dim
+        else:
+            input_dim_for_lora = self.kernel.shape[0]
+
         self.lora_kernel_a = self.add_weight(
             name="lora_kernel_a",
-            shape=(self.kernel.shape[0], rank),
+            shape=(input_dim_for_lora, rank),
             initializer=initializers.get(a_initializer),
             regularizer=self.kernel_regularizer,
         )
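The packed-versus-unpacked distinction above can be made concrete with a small shape example. The numbers below are hypothetical and only illustrate why `lora_kernel_a` must use the original input dimension rather than the packed kernel's first dimension:

    # Illustrative sketch (not part of the diff): a 5-input, 16-unit Dense
    # layer with LoRA rank 4.
    input_dim, units, rank = 5, 16, 4
    packed_rows = (input_dim + 1) // 2   # int4 kernel is stored as (3, 16)
    lora_a_shape = (input_dim, rank)     # must be (5, 4), not (3, 4), because
                                         # lora_kernel_a multiplies the unpacked
                                         # float inputs of shape (..., 5)
    print(packed_rows, lora_a_shape)     # -> 3 (5, 4)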
@@ -211,7 +224,7 @@ def save_own_variables(self, store):
         if self.use_bias:
             target_variables.append(self.bias)
         if self.quantization_mode is not None:
-            if self.quantization_mode == "int8":
+            if self.quantization_mode in ("int8", "int4"):
                 target_variables.append(kernel_scale)
             elif self.quantization_mode == "float8":
                 target_variables.append(self.inputs_scale)
@@ -237,7 +250,7 @@ def load_own_variables(self, store):
         if self.use_bias:
             target_variables.append(self.bias)
         if self.quantization_mode is not None:
-            if self.quantization_mode == "int8":
+            if self.quantization_mode in ("int8", "int4"):
                 target_variables.append(self.kernel_scale)
             elif self.quantization_mode == "float8":
                 target_variables.append(self.inputs_scale)
@@ -315,6 +328,8 @@ def _check_load_own_variables(self, store):
     def quantized_build(self, kernel_shape, mode):
         if mode == "int8":
             self._int8_build(kernel_shape)
+        elif mode == "int4":
+            self._int4_build(kernel_shape)
         elif mode == "float8":
             self._float8_build()
         else:
@@ -337,6 +352,39 @@ def _int8_build(self, kernel_shape):
             trainable=False,
         )

+    def _int4_build(self, kernel_shape):
+        """Build variables for int4 quantization.
+
+        `kernel_shape` is the *original* float32 kernel shape
+        `(input_dim, units)`. We allocate the stored kernel with
+        `ceil(input_dim/2)` rows because two int4 values are packed into a
+        single int8 byte.
+        """
+        # Per-channel int8 quantizer for the last axis (features).
+        self.inputs_quantizer = quantizers.AbsMaxQuantizer(
+            axis=-1,
+        )
+        input_dim, output_dim = kernel_shape
+        packed_rows = (input_dim + 1) // 2  # ceil for odd dims
+
+        # The kernel is stored *packed*: each int8 byte holds two int4 values.
+        self._kernel = self.add_weight(
+            name="kernel",
+            shape=(packed_rows, output_dim),
+            initializer="zeros",
+            dtype="int8",
+            trainable=False,
+        )
+        # One scale per output unit (per-channel).
+        self.kernel_scale = self.add_weight(
+            name="kernel_scale",
+            shape=(self.units,),
+            initializer="ones",
+            trainable=False,
+        )
+        # Record the original input_dim for unpacking at runtime.
+        self._orig_input_dim = input_dim
+
     def _float8_build(self):
         from keras.src.dtype_policies import QuantizedFloat8DTypePolicy
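As a rough illustration of the packing the docstring describes, the snippet below stores two signed int4 values in one int8 byte. It is a plain-Python sketch of *one plausible* nibble layout; the actual layout used by `quantizers.pack_int4` / `unpack_int4` is defined in `keras.src.quantizers` and may differ:

    def pack_pair(lo, hi):
        # Both values must be in the int4 range [-8, 7]; `lo` goes in the low
        # nibble, `hi` in the high nibble of a single byte.
        byte = ((hi & 0x0F) << 4) | (lo & 0x0F)     # unsigned 0..255
        return byte - 256 if byte >= 128 else byte  # reinterpret as signed int8

    def unpack_pair(byte):
        u = byte & 0xFF                             # back to unsigned 0..255
        to_signed = lambda v: v - 16 if v >= 8 else v
        return to_signed(u & 0x0F), to_signed((u >> 4) & 0x0F)

    assert unpack_pair(pack_pair(-3, 7)) == (-3, 7)
    assert unpack_pair(pack_pair(5, -8)) == (5, -8)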
@@ -383,6 +431,16 @@ def _float8_build(self):
     def _int8_call(self, inputs, training=None):
         @ops.custom_gradient
         def matmul_with_inputs_gradient(inputs, kernel, kernel_scale):
+            """Custom gradient function to handle the int8 quantized weights.
+
+            Automatic differentiation does not know how to handle the int8
+            quantized weights, so a custom gradient function is needed. It
+            uses the dequantized kernel to compute the gradient with respect
+            to the inputs.
+            """
+
             def grad_fn(*args, upstream=None):
                 if upstream is None:
                     (upstream,) = args
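The arithmetic behind this forward pass and its custom gradient can be sketched with plain NumPy. The scale convention below (scale = 127 / abs-max, dequantize by dividing by the scale) mirrors the `ops.divide(..., kernel_scale)` calls in the diff, but the concrete values and helper names are illustrative assumptions:

    import numpy as np

    x = np.array([[0.2, -0.5, 1.0]], dtype="float32")                 # inputs
    w = np.array([[0.4, -0.9], [0.1, 0.3], [-0.7, 0.8]], dtype="float32")  # float kernel

    s_x = 127.0 / np.max(np.abs(x), axis=-1, keepdims=True)  # per-row input scale
    s_w = 127.0 / np.max(np.abs(w), axis=0, keepdims=True)   # per-column kernel scale
    x_q = np.round(x * s_x)                                   # int8-range values
    w_q = np.round(w * s_w)

    # Integer-style matmul, then divide by the product of scales to
    # approximate the float matmul x @ w.
    y = (x_q @ w_q) / (s_x * s_w)
    print(np.max(np.abs(y - x @ w)))   # small quantization error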
@@ -415,6 +473,59 @@ def grad_fn(*args, upstream=None):
             x = self.activation(x)
         return x

+    def _int4_call(self, inputs, training=None):
+        """Forward pass for the int4-quantized Dense layer."""
+
+        @ops.custom_gradient
+        def matmul_with_inputs_gradient(inputs, kernel, kernel_scale):
+            """Custom gradient function for int4 quantized weights.
+
+            Automatic differentiation does not know how to handle the int4
+            quantized weights, so a custom gradient function is needed. It
+            uses the dequantized kernel to compute the gradient with respect
+            to the inputs.
+            """
+
+            unpacked_kernel = quantizers.unpack_int4(
+                kernel, self._orig_input_dim
+            )
+
+            def grad_fn(*args, upstream=None):
+                if upstream is None:
+                    (upstream,) = args
+                float_kernel = ops.divide(
+                    ops.cast(unpacked_kernel, dtype=self.compute_dtype),
+                    kernel_scale,
+                )
+                inputs_grad = ops.matmul(upstream, ops.transpose(float_kernel))
+                return (inputs_grad, None, None)
+
+            inputs, inputs_scale = self.inputs_quantizer(inputs)
+            x = ops.matmul(inputs, unpacked_kernel)
+            x = ops.cast(x, self.compute_dtype)
+            x = ops.divide(x, ops.multiply(inputs_scale, kernel_scale))
+            return x, grad_fn
+
+        x = matmul_with_inputs_gradient(
+            inputs,
+            ops.convert_to_tensor(self._kernel),
+            ops.convert_to_tensor(self.kernel_scale),
+        )
+
+        if self.lora_enabled:
+            lora_x = ops.matmul(inputs, self.lora_kernel_a)
+            lora_x = ops.matmul(lora_x, self.lora_kernel_b)
+            x = ops.add(x, (self.lora_alpha / self.lora_rank) * lora_x)
+
+        # Add bias and activation.
+        if self.bias is not None:
+            x = ops.add(x, self.bias)
+        if self.activation is not None:
+            x = self.activation(x)
+        return x
+
     def _float8_call(self, inputs, training=None):
         if self.lora_enabled:
             raise NotImplementedError(
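The int4 path follows the same recipe as the int8 one, just with a narrower value range for the kernel. A short NumPy sketch under an assumed scale convention (7 / abs-max for the int4 side, which is an assumption rather than something stated in the diff):

    import numpy as np

    x = np.array([[0.3, -0.8]], dtype="float32")
    w = np.array([[0.5, -0.25], [1.0, 0.75]], dtype="float32")

    s_x = 127.0 / np.max(np.abs(x), axis=-1, keepdims=True)  # int8 input scale
    s_w = 7.0 / np.max(np.abs(w), axis=0, keepdims=True)     # int4 kernel scale, per column
    x_q = np.round(x * s_x)                                   # values in [-127, 127]
    w_q = np.clip(np.round(w * s_w), -8, 7)                   # values in [-8, 7]

    y = (x_q @ w_q) / (s_x * s_w)                             # dequantize the result
    print(np.abs(y - x @ w))   # coarser than int8, but the structure is identical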
@@ -518,32 +629,117 @@ def quantize(self, mode, type_check=True):
             )
             kernel_scale = ops.squeeze(kernel_scale, axis=0)
             del self._kernel
-        self.quantized_build(kernel_shape, mode)
-        if mode == "int8":
+            # Build variables for int8 mode.
+            self.quantized_build(kernel_shape, mode)
             self._kernel.assign(kernel_value)
             self.kernel_scale.assign(kernel_scale)
+        elif mode == "int4":
+            # 1. Quantize to int4 values (still int8 dtype, range [-8, 7]).
+            kernel_value_int4, kernel_scale = quantizers.abs_max_quantize(
+                self._kernel,
+                axis=0,
+                value_range=(-8, 7),
+                dtype="int8",
+                to_numpy=True,
+            )
+            kernel_scale = ops.squeeze(kernel_scale, axis=0)
+            # 2. Pack two int4 values into a single int8 byte.
+            packed_kernel_value, _, _ = quantizers.pack_int4(kernel_value_int4)
+            del self._kernel
+            # Build variables using the original kernel shape; _int4_build
+            # will compute the packed shape internally.
+            self.quantized_build(kernel_shape, mode)
+            # Assign packed values.
+            self._kernel.assign(packed_kernel_value)
+            self.kernel_scale.assign(kernel_scale)
+        elif mode == "float8":
+            self.quantized_build(kernel_shape, mode)
+        else:
+            raise self._quantization_mode_error(mode)

-        # Set new dtype policy
+        # Only set a new dtype policy if the current one is not already a
+        # quantized policy.
         if self.dtype_policy.quantization_mode is None:
+            from keras.src import dtype_policies  # local import to avoid cycle
+
             policy = dtype_policies.get(f"{mode}_from_{self.dtype_policy.name}")
             self.dtype_policy = policy

     def _get_kernel_with_merged_lora(self):
+        """Returns the kernel with LoRA matrices merged, for serialization.
+
+        This method is called by `save_own_variables` to produce a single
+        kernel tensor that includes the LoRA adaptations. This is useful for
+        deploying the model or for continuing training after permanently
+        applying the LoRA update.
+
+        If the layer is quantized (`int8` or `int4`), the process is:
+        1. Dequantize the base kernel to float.
+        2. Compute the LoRA delta (`lora_kernel_a @ lora_kernel_b`) and add
+           it to the dequantized kernel.
+        3. Re-quantize the merged result back to the original quantized
+           type (`int8` or packed `int4`), calculating a new scale factor.
+
+        If the layer is not quantized, this method returns the result of the
+        `kernel` property (which computes the merge in floating point) and a
+        scale of `None`.
+
+        If LoRA is not enabled, it returns the original kernel and scale
+        without modification.
+
+        Returns:
+            A tuple `(kernel_value, kernel_scale)`:
+            `kernel_value`: The merged kernel. A quantized tensor if
+                quantization is active, otherwise a high precision tensor.
+            `kernel_scale`: The quantization scale for the merged kernel.
+                This is `None` if the layer is not quantized.
+        """
         if self.dtype_policy.quantization_mode is not None:
             kernel_value = self._kernel
             kernel_scale = self.kernel_scale
             if self.lora_enabled:
-                # Dequantize & quantize to merge lora weights into int8 kernel
-                # Note that this is a lossy compression
-                kernel_value = ops.divide(kernel_value, kernel_scale)
-                kernel_value = ops.add(
-                    kernel_value,
-                    (self.lora_alpha / self.lora_rank)
-                    * ops.matmul(self.lora_kernel_a, self.lora_kernel_b),
+                # Dequantize the kernel to float.
+                if self.quantization_mode == "int4":
+                    unpacked_kernel = quantizers.unpack_int4(
+                        kernel_value, self._orig_input_dim
+                    )
+                    float_kernel = ops.divide(
+                        ops.cast(unpacked_kernel, self.compute_dtype),
+                        kernel_scale,
+                    )
+                    quant_range = (-8, 7)
+                elif self.quantization_mode == "int8":
+                    float_kernel = ops.divide(
+                        ops.cast(kernel_value, self.compute_dtype), kernel_scale
+                    )
+                    quant_range = (-127, 127)
+                else:
+                    raise ValueError(
+                        "Unsupported quantization mode: "
+                        f"{self.quantization_mode}"
+                    )
+
+                # Merge the LoRA weights in the float domain.
+                lora_delta = (self.lora_alpha / self.lora_rank) * ops.matmul(
+                    self.lora_kernel_a, self.lora_kernel_b
                 )
-                kernel_value, kernel_scale = quantizers.abs_max_quantize(
-                    kernel_value, axis=0, to_numpy=True
+                merged_float_kernel = ops.add(float_kernel, lora_delta)
+
+                # Requantize.
+                requantized_kernel, kernel_scale = quantizers.abs_max_quantize(
+                    merged_float_kernel,
+                    axis=0,
+                    value_range=quant_range,
+                    dtype="int8",
+                    to_numpy=True,
                 )
                 kernel_scale = ops.squeeze(kernel_scale, axis=0)
+
+                # Pack the values back into int4 if needed.
+                if self.quantization_mode == "int4":
+                    kernel_value, _, _ = quantizers.pack_int4(
+                        requantized_kernel
+                    )
+                else:
+                    kernel_value = requantized_kernel
             return kernel_value, kernel_scale
         return self.kernel, None
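For the int8 case, the dequantize → merge → requantize flow implemented by the new branch can be summarized in a compact NumPy sketch. The helper name, scale convention, and example values are illustrative assumptions, not Keras API:

    import numpy as np

    def merge_lora_int8(w_q, s_w, lora_a, lora_b, alpha, rank):
        float_kernel = w_q.astype("float32") / s_w                  # 1. dequantize
        merged = float_kernel + (alpha / rank) * (lora_a @ lora_b)  # 2. add LoRA delta
        new_scale = 127.0 / np.max(np.abs(merged), axis=0, keepdims=True)
        requantized = np.round(merged * new_scale)                  # 3. requantize (lossy)
        return requantized.astype("int8"), new_scale

    rng = np.random.default_rng(0)
    w = rng.uniform(-1, 1, (8, 4)).astype("float32")
    s_w = 127.0 / np.max(np.abs(w), axis=0, keepdims=True)
    w_q = np.round(w * s_w)
    a = 0.05 * rng.normal(size=(8, 2)).astype("float32")
    b = 0.05 * rng.normal(size=(2, 4)).astype("float32")
    merged_q, merged_scale = merge_lora_int8(w_q, s_w, a, b, alpha=2.0, rank=2)

The int4 case follows the same three steps, with `(-8, 7)` as the target range and a final packing step before assignment.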