Commit df481e9

Adds int4 quantization support to EinsumDense (#21471)
* einsum dense int4 support
* simplify dense lora kernel merge
* Fix docstring
* reduce duplicated code
* adjust numerics tolerance for int4
* reduce code duplication in int4 and int8 paths
1 parent b6f178e commit df481e9
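
Of the commit-message items above, the adjusted numerics tolerance for int4 follows from the coarser grid: symmetric abs-max int4 quantization has only 16 levels in [-8, 7], versus 255 in [-127, 127] for int8. Below is a minimal NumPy sketch of that abs-max scheme; the helper name `abs_max_quantize` mirrors `keras.quantizers.abs_max_quantize` in spirit only and is not the real API.

```python
import numpy as np

def abs_max_quantize(w, lo, hi, axis=0):
    """Toy symmetric abs-max quantization (illustrative, not the Keras API)."""
    # Per-column scale so that max(|w|) maps to the edge of the integer range.
    amax = np.max(np.abs(w), axis=axis, keepdims=True)
    scale = hi / np.maximum(amax, 1e-9)
    q = np.clip(np.round(w * scale), lo, hi).astype(np.int8)
    return q, scale

w = np.random.randn(8, 4).astype(np.float32)
q4, s4 = abs_max_quantize(w, -8, 7)       # int4 range
q8, s8 = abs_max_quantize(w, -127, 127)   # int8 range

# Dequantize by dividing by the scale, as the layer does at merge time.
err4 = np.abs(w - q4 / s4).max()
err8 = np.abs(w - q8 / s8).max()
print(err4, err8)  # err4 is much larger, motivating the looser int4 tolerance
```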

File tree

3 files changed: +610 / -185 lines changed


keras/src/layers/core/dense.py

Lines changed: 52 additions & 49 deletions
@@ -693,53 +693,56 @@ def _get_kernel_with_merged_lora(self):
             `kernel_scale`: The quantization scale for the merged kernel.
                 This is `None` if the layer is not quantized.
         """
-        if self.dtype_policy.quantization_mode is not None:
-            kernel_value = self._kernel
-            kernel_scale = self.kernel_scale
-            if self.lora_enabled:
-                # Dequantize kernel to float
-                if self.quantization_mode == "int4":
-                    unpacked_kernel = quantizers.unpack_int4(
-                        kernel_value, self._orig_input_dim
-                    )
-                    float_kernel = ops.divide(
-                        ops.cast(unpacked_kernel, self.compute_dtype),
-                        kernel_scale,
-                    )
-                    quant_range = (-8, 7)
-                elif self.quantization_mode == "int8":
-                    float_kernel = ops.divide(
-                        ops.cast(kernel_value, self.compute_dtype), kernel_scale
-                    )
-                    quant_range = (-127, 127)
-                else:
-                    raise ValueError(
-                        "Unsupported quantization mode: "
-                        f"{self.quantization_mode}"
-                    )
-
-                # Merge LoRA weights in float domain
-                lora_delta = (self.lora_alpha / self.lora_rank) * ops.matmul(
-                    self.lora_kernel_a, self.lora_kernel_b
-                )
-                merged_float_kernel = ops.add(float_kernel, lora_delta)
-
-                # Requantize
-                requantized_kernel, kernel_scale = quantizers.abs_max_quantize(
-                    merged_float_kernel,
-                    axis=0,
-                    value_range=quant_range,
-                    dtype="int8",
-                    to_numpy=True,
-                )
-                kernel_scale = ops.squeeze(kernel_scale, axis=0)
-
-                # Pack if int4
-                if self.quantization_mode == "int4":
-                    kernel_value, _, _ = quantizers.pack_int4(
-                        requantized_kernel
-                    )
-                else:
-                    kernel_value = requantized_kernel
+        if self.dtype_policy.quantization_mode is None:
+            return self.kernel, None
+
+        kernel_value = self._kernel
+        kernel_scale = self.kernel_scale
+
+        if not self.lora_enabled:
             return kernel_value, kernel_scale
-        return self.kernel, None
+
+        # Dequantize, Merge, and Re-quantize
+
+        # Dequantize kernel to float
+        if self.quantization_mode == "int4":
+            unpacked_kernel = quantizers.unpack_int4(
+                kernel_value, self._orig_input_dim
+            )
+            float_kernel = ops.divide(
+                ops.cast(unpacked_kernel, self.compute_dtype),
+                kernel_scale,
+            )
+            quant_range = (-8, 7)
+        elif self.quantization_mode == "int8":
+            float_kernel = ops.divide(
+                ops.cast(kernel_value, self.compute_dtype), kernel_scale
+            )
+            quant_range = (-127, 127)
+        else:
+            raise ValueError(
+                f"Unsupported quantization mode: {self.quantization_mode}"
+            )
+
+        # Merge LoRA weights in float domain
+        lora_delta = (self.lora_alpha / self.lora_rank) * ops.matmul(
+            self.lora_kernel_a, self.lora_kernel_b
+        )
+        merged_float_kernel = ops.add(float_kernel, lora_delta)
+
+        # Requantize
+        requantized_kernel, kernel_scale = quantizers.abs_max_quantize(
+            merged_float_kernel,
+            axis=0,
+            value_range=quant_range,
+            dtype="int8",
+            to_numpy=True,
+        )
+        kernel_scale = ops.squeeze(kernel_scale, axis=0)
+
+        # Pack if int4
+        if self.quantization_mode == "int4":
+            kernel_value, _, _ = quantizers.pack_int4(requantized_kernel)
+        else:
+            kernel_value = requantized_kernel
+        return kernel_value, kernel_scale
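
After the refactor the merge path reads straight down: early-return, dequantize, merge, requantize. For illustration, a minimal end-to-end NumPy sketch of that dequantize → merge LoRA → requantize flow; the standalone helper is hypothetical, while the real code uses `quantizers.abs_max_quantize` with per-output-unit scales via `axis=0`.

```python
import numpy as np

def abs_max_quantize(w, hi=7, axis=0):
    # Per-output-unit scale, mirroring the axis=0 call in the diff above.
    amax = np.max(np.abs(w), axis=axis, keepdims=True)
    scale = hi / np.maximum(amax, 1e-9)
    return np.clip(np.round(w * scale), -hi - 1, hi).astype(np.int8), scale

rank, alpha = 4, 8.0
kernel = np.random.randn(16, 32).astype(np.float32)             # base kernel
lora_a = (0.01 * np.random.randn(16, rank)).astype(np.float32)  # lora_kernel_a
lora_b = (0.01 * np.random.randn(rank, 32)).astype(np.float32)  # lora_kernel_b

q, scale = abs_max_quantize(kernel)              # stored quantized kernel
float_kernel = q.astype(np.float32) / scale      # 1) dequantize to float
lora_delta = (alpha / rank) * (lora_a @ lora_b)  # 2) LoRA delta...
merged = float_kernel + lora_delta               #    ...merged in float domain
q_merged, new_scale = abs_max_quantize(merged)   # 3) requantize
```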

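The int4 branch additionally packs two 4-bit values per byte before storage (`quantizers.pack_int4` / `unpack_int4` in the diff). Below is a toy sketch of such nibble packing along axis 0; it ignores the shape/padding bookkeeping and extra return values of the real helpers, and the low/high-nibble layout is an assumption chosen only so that the toy pair round-trips.

```python
import numpy as np

def pack_int4(q):
    """Pack signed int4 rows two-per-byte (toy stand-in for pack_int4)."""
    if q.shape[0] % 2:  # pad to an even row count
        q = np.concatenate([q, np.zeros((1,) + q.shape[1:], q.dtype)])
    u = (q & 0x0F).astype(np.uint8)  # two's-complement low nibbles
    return u[0::2] | (u[1::2] << 4)  # even row -> low nibble, odd -> high

def unpack_int4(packed, orig_rows):
    """Inverse of the toy pack_int4: recover the signed int4 rows."""
    out = np.empty((2 * packed.shape[0],) + packed.shape[1:], np.int8)
    out[0::2] = packed & 0x0F  # low nibbles
    out[1::2] = packed >> 4    # high nibbles
    out[out > 7] -= 16         # sign-extend each nibble
    return out[:orig_rows]

q = np.array([[-8, 7], [3, -1], [5, 0]], dtype=np.int8)
assert (unpack_int4(pack_int4(q), q.shape[0]) == q).all()
```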