diff --git a/keras/src/layers/core/dense.py b/keras/src/layers/core/dense.py
index 317d6bbd28fc..8db7e1b22472 100644
--- a/keras/src/layers/core/dense.py
+++ b/keras/src/layers/core/dense.py
@@ -345,12 +345,14 @@ def _int8_build(self, kernel_shape):
             dtype="int8",
             trainable=False,
         )
+        self._kernel._is_quantized = True
         self.kernel_scale = self.add_weight(
             name="kernel_scale",
             shape=(self.units,),
             initializer="ones",
             trainable=False,
         )
+        self.kernel_scale._is_quantized = True
 
     def _int4_build(self, kernel_shape):
         """Build variables for int4 quantization.
@@ -375,6 +377,7 @@ def _int4_build(self, kernel_shape):
             dtype="int8",
             trainable=False,
         )
+        self._kernel._is_quantized = True
         # One scale per output unit (per-channel).
         self.kernel_scale = self.add_weight(
             name="kernel_scale",
@@ -382,6 +385,7 @@ def _int4_build(self, kernel_shape):
             initializer="ones",
             trainable=False,
         )
+        self.kernel_scale._is_quantized = True
         # Record original input_dim for unpacking at runtime.
         self._orig_input_dim = input_dim
 
@@ -414,19 +418,25 @@ def _float8_build(self):
             "overwrite_with_gradient": True,
         }
         self.inputs_scale = self.add_weight(name="inputs_scale", **scale_kwargs)
+        self.inputs_scale._is_quantized = True
         self.inputs_amax_history = self.add_weight(
             name="inputs_amax_history", **amax_history_kwargs
         )
+        self.inputs_amax_history._is_quantized = True
         self.kernel_scale = self.add_weight(name="kernel_scale", **scale_kwargs)
+        self.kernel_scale._is_quantized = True
         self.kernel_amax_history = self.add_weight(
             name="kernel_amax_history", **amax_history_kwargs
         )
+        self.kernel_amax_history._is_quantized = True
         self.outputs_grad_scale = self.add_weight(
             name="outputs_grad_scale", **scale_kwargs
         )
+        self.outputs_grad_scale._is_quantized = True
         self.outputs_grad_amax_history = self.add_weight(
             name="outputs_grad_amax_history", **amax_history_kwargs
         )
+        self.outputs_grad_amax_history._is_quantized = True
 
     def _int8_call(self, inputs, training=None):
         @ops.custom_gradient
diff --git a/keras/src/layers/core/einsum_dense.py b/keras/src/layers/core/einsum_dense.py
index 39100486cda8..e2f99217906e 100644
--- a/keras/src/layers/core/einsum_dense.py
+++ b/keras/src/layers/core/einsum_dense.py
@@ -412,6 +412,7 @@ def _int8_build(self, kernel_shape):
             dtype="int8",
             trainable=False,
         )
+        self._kernel._is_quantized = True
         kernel_scale_shape = self._get_kernel_scale_shape(kernel_shape)
         self.kernel_scale = self.add_weight(
             name="kernel_scale",
@@ -419,6 +420,7 @@ def _int8_build(self, kernel_shape):
             initializer="ones",
             trainable=False,
         )
+        self.kernel_scale._is_quantized = True
 
     def _int4_build(self, kernel_shape):
         """Build variables for int4 quantization.
@@ -461,6 +463,7 @@ def _int4_build(self, kernel_shape):
             dtype="int8",
             trainable=False,
         )
+        self._kernel._is_quantized = True
         # Kernel scale
         kernel_scale_shape = self._get_kernel_scale_shape(kernel_shape)
         self.kernel_scale = self.add_weight(
@@ -470,6 +473,7 @@ def _int4_build(self, kernel_shape):
             initializer="ones",
             trainable=False,
         )
+        self.kernel_scale._is_quantized = True
 
     def _float8_build(self):
         from keras.src.dtype_policies import QuantizedFloat8DTypePolicy
@@ -500,19 +504,25 @@ def _float8_build(self):
             "overwrite_with_gradient": True,
         }
         self.inputs_scale = self.add_weight(name="inputs_scale", **scale_kwargs)
+        self.inputs_scale._is_quantized = True
         self.inputs_amax_history = self.add_weight(
             name="inputs_amax_history", **amax_history_kwargs
         )
+        self.inputs_amax_history._is_quantized = True
         self.kernel_scale = self.add_weight(name="kernel_scale", **scale_kwargs)
+        self.kernel_scale._is_quantized = True
         self.kernel_amax_history = self.add_weight(
             name="kernel_amax_history", **amax_history_kwargs
         )
+        self.kernel_amax_history._is_quantized = True
         self.outputs_grad_scale = self.add_weight(
             name="outputs_grad_scale", **scale_kwargs
         )
+        self.outputs_grad_scale._is_quantized = True
         self.outputs_grad_amax_history = self.add_weight(
             name="outputs_grad_amax_history", **amax_history_kwargs
         )
+        self.outputs_grad_amax_history._is_quantized = True
 
     def _int8_call(self, inputs, training=None):
         @ops.custom_gradient
diff --git a/keras/src/layers/core/embedding.py b/keras/src/layers/core/embedding.py
index e9a207daf3df..b982c6406c0b 100644
--- a/keras/src/layers/core/embedding.py
+++ b/keras/src/layers/core/embedding.py
@@ -328,6 +328,7 @@ def _int8_build(self, embeddings_shape):
             dtype="int8",
             trainable=False,
         )
+        self._embeddings._is_quantized = True
         # We choose to reduce the axis of `output_dim` because, typically,
         # `input_dim` is larger than `output_dim`. This reduces quantization
         # error.
@@ -337,6 +338,7 @@ def _int8_build(self, embeddings_shape):
             initializer="ones",
             trainable=False,
         )
+        self.embeddings_scale._is_quantized = True
 
     def quantized_call(self, *args, **kwargs):
         if self.quantization_mode != "int8":
diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py
index dc3d5d9e0c64..c7f189df60f5 100644
--- a/keras/src/layers/layer.py
+++ b/keras/src/layers/layer.py
@@ -1359,10 +1359,33 @@ def save_own_variables(self, store):
         Args:
             store: Dict where the state of the model will be saved.
         """
+        if not getattr(self, "_is_quantized", False):
+            all_vars = self._trainable_variables + self._non_trainable_variables
+            for i, v in enumerate(all_vars):
+                store[f"{i}"] = v
+            return
+
+        # Case: quantized layer
+        quantized_vars = self._get_quantized_variables()
+        for i, v in enumerate(quantized_vars):
+            store[f"quantized_{i}"] = v
+
+        # Save non-quantized variables
         all_vars = self._trainable_variables + self._non_trainable_variables
-        for i, v in enumerate(all_vars):
+        quantized_vars_set = set(quantized_vars)
+        non_quantized_vars = [
+            v for v in all_vars if v not in quantized_vars_set
+        ]
+        for i, v in enumerate(non_quantized_vars):
             store[f"{i}"] = v
 
+    def _get_quantized_variables(self):
+        quantized_vars = []
+        for v in self._trainable_variables + self._non_trainable_variables:
+            if getattr(v, "_is_quantized", False):
+                quantized_vars.append(v)
+        return quantized_vars
+
     def load_own_variables(self, store):
         """Loads the state of the layer.
 
@@ -1372,6 +1395,10 @@ def load_own_variables(self, store):
         Args:
             store: Dict from which the state of the model will be loaded.
         """
+        if any(key.startswith("quantized_") for key in store.keys()):
+            self._load_quantized_variables(store)
+            return
+
         all_vars = self._trainable_variables + self._non_trainable_variables
         if len(store.keys()) != len(all_vars):
             if len(all_vars) == 0 and not self.built:
@@ -1407,6 +1434,20 @@ def load_own_variables(self, store):
         for i, v in enumerate(all_vars):
             v.assign(store[f"{i}"])
 
+    def _load_quantized_variables(self, store):
+        quantized_vars = self._get_quantized_variables()
+        for i, v in enumerate(quantized_vars):
+            v.assign(store[f"quantized_{i}"])
+
+        # Load non-quantized variables
+        all_vars = self._trainable_variables + self._non_trainable_variables
+        quantized_vars_set = set(quantized_vars)
+        non_quantized_vars = [
+            v for v in all_vars if v not in quantized_vars_set
+        ]
+        for i, v in enumerate(non_quantized_vars):
+            v.assign(store[f"{i}"])
+
     def _track_variable(self, variable):
         if variable.trainable:
             self._tracker.add_to_store("trainable_variables", variable)
diff --git a/keras/src/layers/layer_test.py b/keras/src/layers/layer_test.py
index aa27eb9aac71..974e6aa61534 100644
--- a/keras/src/layers/layer_test.py
+++ b/keras/src/layers/layer_test.py
@@ -1,3 +1,4 @@
+import os
 import pickle
 from unittest import mock
 
@@ -12,6 +13,7 @@
 from keras.src import metrics
 from keras.src import models
 from keras.src import ops
+from keras.src import saving
 from keras.src import testing
 from keras.src.backend.common import global_state
 from keras.src.backend.common.remat import RematScope
@@ -1758,3 +1760,16 @@ def call(self, x):
         # foo_mode omitted -> foo_mode defaults to False -> no change
         y2 = model(sample_input)
         self.assertAllClose(y2, sample_input)
+
+    def test_quantized_model_save_and_load(self):
+        inputs = layers.Input(shape=(None,))
+        x = layers.Embedding(input_dim=10, output_dim=10)(inputs)
+        x = layers.Dense(10)(x)
+        model = models.Model(inputs=inputs, outputs=x)
+        path = os.path.join(self.get_temp_dir(), "quantized_model.keras")
+        model.quantize(mode="int8")
+        model.save(path)
+
+        quantized_model = saving.load_model(path)
+
+        self.assertTrue(quantized_model.built)