From 886bfbe1fcf6c1735117a81029b47fdce36d12d2 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Wed, 23 Jul 2025 21:50:53 +0000 Subject: [PATCH 1/3] fix quantization save and load error --- keras/src/layers/layer.py | 41 +++++++++++++++++++++++++++++++++- keras/src/layers/layer_test.py | 15 +++++++++++++ 2 files changed, 55 insertions(+), 1 deletion(-) diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index dc3d5d9e0c64..8aee6478d4c4 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -1359,10 +1359,32 @@ def save_own_variables(self, store): Args: store: Dict where the state of the model will be saved. """ + if not getattr(self, "_is_quantized", False): + all_vars = self._trainable_variables + self._non_trainable_variables + for i, v in enumerate(all_vars): + store[f"{i}"] = v + return + + # Case: quantized layer + quantized_vars = self._get_quantized_variables() + for i, v in enumerate(quantized_vars): + store[f"quantized_{i}"] = v + + # Save non-quantized variables all_vars = self._trainable_variables + self._non_trainable_variables - for i, v in enumerate(all_vars): + non_quantized_vars = [ + v for v in all_vars if v not in quantized_vars and v.trainable + ] + for i, v in enumerate(non_quantized_vars): store[f"{i}"] = v + def _get_quantized_variables(self): + quantized_vars = [] + for v in self._trainable_variables + self._non_trainable_variables: + if not backend.is_float_dtype(v.dtype): + quantized_vars.append(v) + return quantized_vars + def load_own_variables(self, store): """Loads the state of the layer. @@ -1372,6 +1394,10 @@ def load_own_variables(self, store): Args: store: Dict from which the state of the model will be loaded. """ + if any(key.startswith("quantized_") for key in store.keys()): + self._load_quantized_variables(store) + return + all_vars = self._trainable_variables + self._non_trainable_variables if len(store.keys()) != len(all_vars): if len(all_vars) == 0 and not self.built: @@ -1407,6 +1433,19 @@ def load_own_variables(self, store): for i, v in enumerate(all_vars): v.assign(store[f"{i}"]) + def _load_quantized_variables(self, store): + quantized_vars = self._get_quantized_variables() + for i, v in enumerate(quantized_vars): + v.assign(store[f"quantized_{i}"]) + + # Load non-quantized variables + all_vars = self._trainable_variables + self._non_trainable_variables + non_quantized_vars = [ + v for v in all_vars if v not in quantized_vars and v.trainable + ] + for i, v in enumerate(non_quantized_vars): + v.assign(store[f"{i}"]) + def _track_variable(self, variable): if variable.trainable: self._tracker.add_to_store("trainable_variables", variable) diff --git a/keras/src/layers/layer_test.py b/keras/src/layers/layer_test.py index aa27eb9aac71..974e6aa61534 100644 --- a/keras/src/layers/layer_test.py +++ b/keras/src/layers/layer_test.py @@ -1,3 +1,4 @@ +import os import pickle from unittest import mock @@ -12,6 +13,7 @@ from keras.src import metrics from keras.src import models from keras.src import ops +from keras.src import saving from keras.src import testing from keras.src.backend.common import global_state from keras.src.backend.common.remat import RematScope @@ -1758,3 +1760,16 @@ def call(self, x): # foo_mode omitted -> foo_mode defaults to False -> no change y2 = model(sample_input) self.assertAllClose(y2, sample_input) + + def test_quantized_model_save_and_load(self): + inputs = layers.Input(shape=(None,)) + x = layers.Embedding(input_dim=10, output_dim=10)(inputs) + x = layers.Dense(10)(x) + model = 
models.Model(inputs=inputs, outputs=x) + path = os.path.join(self.get_temp_dir(), "quantized_model.keras") + model.quantize(mode="int8") + model.save(path) + + quantized_model = saving.load_model(path) + + self.assertTrue(quantized_model.built) From eb67b46c80e381407174da01329166160a61f8e9 Mon Sep 17 00:00:00 2001 From: divyashreepathihalli Date: Wed, 23 Jul 2025 22:51:56 +0000 Subject: [PATCH 2/3] address gemini comments --- keras/src/layers/layer.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index 8aee6478d4c4..8d857f9810bd 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -1372,8 +1372,9 @@ def save_own_variables(self, store): # Save non-quantized variables all_vars = self._trainable_variables + self._non_trainable_variables + quantized_vars_set = set(quantized_vars) non_quantized_vars = [ - v for v in all_vars if v not in quantized_vars and v.trainable + v for v in all_vars if v not in quantized_vars_set ] for i, v in enumerate(non_quantized_vars): store[f"{i}"] = v @@ -1440,8 +1441,9 @@ def _load_quantized_variables(self, store): # Load non-quantized variables all_vars = self._trainable_variables + self._non_trainable_variables + quantized_vars_set = set(quantized_vars) non_quantized_vars = [ - v for v in all_vars if v not in quantized_vars and v.trainable + v for v in all_vars if v not in quantized_vars_set ] for i, v in enumerate(non_quantized_vars): v.assign(store[f"{i}"]) From 6daa6a5b6c7bd3a55d9315707ac9c2e4828fc3b4 Mon Sep 17 00:00:00 2001 From: Divyashree Sreepathihalli Date: Wed, 23 Jul 2025 20:10:46 -0700 Subject: [PATCH 3/3] change quantized variables to be saved --- keras/src/layers/core/dense.py | 10 ++++++++++ keras/src/layers/core/einsum_dense.py | 10 ++++++++++ keras/src/layers/core/embedding.py | 2 ++ keras/src/layers/layer.py | 2 +- 4 files changed, 23 insertions(+), 1 deletion(-) diff --git a/keras/src/layers/core/dense.py b/keras/src/layers/core/dense.py index 317d6bbd28fc..8db7e1b22472 100644 --- a/keras/src/layers/core/dense.py +++ b/keras/src/layers/core/dense.py @@ -345,12 +345,14 @@ def _int8_build(self, kernel_shape): dtype="int8", trainable=False, ) + self._kernel._is_quantized = True self.kernel_scale = self.add_weight( name="kernel_scale", shape=(self.units,), initializer="ones", trainable=False, ) + self.kernel_scale._is_quantized = True def _int4_build(self, kernel_shape): """Build variables for int4 quantization. @@ -375,6 +377,7 @@ def _int4_build(self, kernel_shape): dtype="int8", trainable=False, ) + self._kernel._is_quantized = True # One scale per output unit (per-channel). self.kernel_scale = self.add_weight( name="kernel_scale", @@ -382,6 +385,7 @@ def _int4_build(self, kernel_shape): initializer="ones", trainable=False, ) + self.kernel_scale._is_quantized = True # Record original input_dim for unpacking at runtime. 
self._orig_input_dim = input_dim @@ -414,19 +418,25 @@ def _float8_build(self): "overwrite_with_gradient": True, } self.inputs_scale = self.add_weight(name="inputs_scale", **scale_kwargs) + self.inputs_scale._is_quantized = True self.inputs_amax_history = self.add_weight( name="inputs_amax_history", **amax_history_kwargs ) + self.inputs_amax_history._is_quantized = True self.kernel_scale = self.add_weight(name="kernel_scale", **scale_kwargs) + self.kernel_scale._is_quantized = True self.kernel_amax_history = self.add_weight( name="kernel_amax_history", **amax_history_kwargs ) + self.kernel_amax_history._is_quantized = True self.outputs_grad_scale = self.add_weight( name="outputs_grad_scale", **scale_kwargs ) + self.outputs_grad_scale._is_quantized = True self.outputs_grad_amax_history = self.add_weight( name="outputs_grad_amax_history", **amax_history_kwargs ) + self.outputs_grad_amax_history._is_quantized = True def _int8_call(self, inputs, training=None): @ops.custom_gradient diff --git a/keras/src/layers/core/einsum_dense.py b/keras/src/layers/core/einsum_dense.py index 39100486cda8..e2f99217906e 100644 --- a/keras/src/layers/core/einsum_dense.py +++ b/keras/src/layers/core/einsum_dense.py @@ -412,6 +412,7 @@ def _int8_build(self, kernel_shape): dtype="int8", trainable=False, ) + self._kernel._is_quantized = True kernel_scale_shape = self._get_kernel_scale_shape(kernel_shape) self.kernel_scale = self.add_weight( name="kernel_scale", @@ -419,6 +420,7 @@ def _int8_build(self, kernel_shape): initializer="ones", trainable=False, ) + self.kernel_scale._is_quantized = True def _int4_build(self, kernel_shape): """Build variables for int4 quantization. @@ -461,6 +463,7 @@ def _int4_build(self, kernel_shape): dtype="int8", trainable=False, ) + self._kernel._is_quantized = True # Kernel scale kernel_scale_shape = self._get_kernel_scale_shape(kernel_shape) @@ -470,6 +473,7 @@ def _int4_build(self, kernel_shape): initializer="ones", trainable=False, ) + self.kernel_scale._is_quantized = True def _float8_build(self): from keras.src.dtype_policies import QuantizedFloat8DTypePolicy @@ -500,19 +504,25 @@ def _float8_build(self): "overwrite_with_gradient": True, } self.inputs_scale = self.add_weight(name="inputs_scale", **scale_kwargs) + self.inputs_scale._is_quantized = True self.inputs_amax_history = self.add_weight( name="inputs_amax_history", **amax_history_kwargs ) + self.inputs_amax_history._is_quantized = True self.kernel_scale = self.add_weight(name="kernel_scale", **scale_kwargs) + self.kernel_scale._is_quantized = True self.kernel_amax_history = self.add_weight( name="kernel_amax_history", **amax_history_kwargs ) + self.kernel_amax_history._is_quantized = True self.outputs_grad_scale = self.add_weight( name="outputs_grad_scale", **scale_kwargs ) + self.outputs_grad_scale._is_quantized = True self.outputs_grad_amax_history = self.add_weight( name="outputs_grad_amax_history", **amax_history_kwargs ) + self.outputs_grad_amax_history._is_quantized = True def _int8_call(self, inputs, training=None): @ops.custom_gradient diff --git a/keras/src/layers/core/embedding.py b/keras/src/layers/core/embedding.py index e9a207daf3df..b982c6406c0b 100644 --- a/keras/src/layers/core/embedding.py +++ b/keras/src/layers/core/embedding.py @@ -328,6 +328,7 @@ def _int8_build(self, embeddings_shape): dtype="int8", trainable=False, ) + self._embeddings._is_quantized = True # We choose to reduce the axis of `output_dim` because, typically, # `input_dim` is larger than `output_dim`. This reduces quantization # error. 
@@ -337,6 +338,7 @@ def _int8_build(self, embeddings_shape): initializer="ones", trainable=False, ) + self.embeddings_scale._is_quantized = True def quantized_call(self, *args, **kwargs): if self.quantization_mode != "int8": diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index 8d857f9810bd..c7f189df60f5 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -1382,7 +1382,7 @@ def save_own_variables(self, store): def _get_quantized_variables(self): quantized_vars = [] for v in self._trainable_variables + self._non_trainable_variables: - if not backend.is_float_dtype(v.dtype): + if getattr(v, "_is_quantized", False): quantized_vars.append(v) return quantized_vars
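
Note on the resulting checkpoint layout (illustrative, not part of the patches): with the final state of this series, variables tagged with `_is_quantized = True` are written under "quantized_{i}" keys while the remaining variables keep the plain "{i}" keys, and `load_own_variables` dispatches on the presence of any "quantized_" key. Below is a minimal sketch of what this is assumed to look like for an int8-quantized Dense layer; the exact key indices and variable ordering are assumptions inferred from the diffs above.

    import numpy as np
    from keras import layers

    # Build and quantize a small Dense layer, then inspect the store that
    # save_own_variables produces under the patched code path.
    layer = layers.Dense(4)
    layer.build((None, 8))
    layer.quantize("int8")

    store = {}
    layer.save_own_variables(store)
    print(sorted(store.keys()))
    # Assumed result: ['0', 'quantized_0', 'quantized_1']
    #   'quantized_0' -> int8 kernel and 'quantized_1' -> kernel_scale (both
    #   tagged with _is_quantized in _int8_build), '0' -> the float bias.

    # Round trip: a freshly built and quantized layer restores from the store.
    restored = layers.Dense(4)
    restored.build((None, 8))
    restored.quantize("int8")
    restored.load_own_variables(store)
    np.testing.assert_allclose(
        restored(np.ones((1, 8))), layer(np.ones((1, 8)))
    )

The same partitioning applies to EinsumDense and Embedding, since their quantized build methods tag their kernel/embedding and scale (and float8 amax-history) variables with the same `_is_quantized` flag.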