From 9615b5bf7c81e62f0464ca8de192f83e6dd5dcfc Mon Sep 17 00:00:00 2001 From: Hongyu Chiu <20734616+james77777778@users.noreply.github.com> Date: Fri, 3 Oct 2025 16:32:59 +0800 Subject: [PATCH] Refactor variable serialization. --- keras/src/layers/core/dense.py | 125 +++++++++----------------- keras/src/layers/core/einsum_dense.py | 125 +++++++++----------------- keras/src/layers/core/embedding.py | 111 +++++++---------------- keras/src/layers/layer.py | 22 +++-- keras/src/saving/file_editor_test.py | 12 +-- 5 files changed, 136 insertions(+), 259 deletions(-) diff --git a/keras/src/layers/core/dense.py b/keras/src/layers/core/dense.py index 7eedbbcc8783..01b6f150fe7f 100644 --- a/keras/src/layers/core/dense.py +++ b/keras/src/layers/core/dense.py @@ -258,25 +258,25 @@ def save_own_variables(self, store): if not self.built: return mode = self.quantization_mode - if mode not in self.quantization_variable_spec: + if mode not in self.variable_serialization_spec: raise self._quantization_mode_error(mode) # Kernel plus optional merged LoRA-aware scale (returns (kernel, None) # for None/gptq) kernel_value, merged_kernel_scale = self._get_kernel_with_merged_lora() - - # Save the variables using the name as the key. - if mode != "gptq": - store["kernel"] = kernel_value - if self.bias is not None: - store["bias"] = self.bias - for name in self.quantization_variable_spec[mode]: - if name == "kernel_scale" and mode in ("int4", "int8"): + idx = 0 + for name in self.variable_serialization_spec[mode]: + if name == "kernel": + store[str(idx)] = kernel_value + elif name == "bias" and self.bias is None: + continue + elif name == "kernel_scale" and mode in ("int4", "int8"): # For int4/int8, the merged LoRA scale (if any) comes from # `_get_kernel_with_merged_lora()` - store[name] = merged_kernel_scale + store[str(idx)] = merged_kernel_scale else: - store[name] = getattr(self, name) + store[str(idx)] = getattr(self, name) + idx += 1 def load_own_variables(self, store): if not self.lora_enabled: @@ -285,39 +285,18 @@ def load_own_variables(self, store): if not self.built: return mode = self.quantization_mode - if mode not in self.quantization_variable_spec: + if mode not in self.variable_serialization_spec: raise self._quantization_mode_error(mode) - # Determine whether to use the legacy loading method. - if "0" in store: - return self._legacy_load_own_variables(store) - - # Load the variables using the name as the key. - if mode != "gptq": - self._kernel.assign(store["kernel"]) - if self.bias is not None: - self.bias.assign(store["bias"]) - for name in self.quantization_variable_spec[mode]: - getattr(self, name).assign(store[name]) - if self.lora_enabled: - self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape)) - self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape)) - - def _legacy_load_own_variables(self, store): - # The keys of the `store` will be saved as determined because the - # default ordering will change after quantization - mode = self.quantization_mode - targets = [] - if mode != "gptq": - targets.append(self._kernel) - if self.bias is not None: - targets.append(self.bias) - targets.extend( - getattr(self, name) - for name in self.quantization_variable_spec[mode] - ) - for i, variable in enumerate(targets): - variable.assign(store[str(i)]) + idx = 0 + for name in self.variable_serialization_spec[mode]: + if name == "kernel": + self._kernel.assign(store[str(idx)]) + elif name == "bias" and self.bias is None: + continue + else: + getattr(self, name).assign(store[str(idx)]) + idx += 1 if self.lora_enabled: self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape)) self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape)) @@ -344,53 +323,32 @@ def get_config(self): config["lora_alpha"] = self.lora_alpha return {**base_config, **config} - def _check_load_own_variables(self, store): - all_vars = self._trainable_variables + self._non_trainable_variables - if len(store.keys()) != len(all_vars): - if len(all_vars) == 0 and not self.built: - raise ValueError( - f"Layer '{self.name}' was never built " - "and thus it doesn't have any variables. " - f"However the weights file lists {len(store.keys())} " - "variables for this layer.\n" - "In most cases, this error indicates that either:\n\n" - "1. The layer is owned by a parent layer that " - "implements a `build()` method, but calling the " - "parent's `build()` method did NOT create the state of " - f"the child layer '{self.name}'. A `build()` method " - "must create ALL state for the layer, including " - "the state of any children layers.\n\n" - "2. You need to implement " - "the `def build_from_config(self, config)` method " - f"on layer '{self.name}', to specify how to rebuild " - "it during loading. " - "In this case, you might also want to implement the " - "method that generates the build config at saving time, " - "`def get_build_config(self)`. " - "The method `build_from_config()` is meant " - "to create the state " - "of the layer (i.e. its variables) upon deserialization.", - ) - raise ValueError( - f"Layer '{self.name}' expected {len(all_vars)} variables, " - "but received " - f"{len(store.keys())} variables during loading. " - f"Expected: {[v.name for v in all_vars]}" - ) - @property - def quantization_variable_spec(self): - """Returns a dict mapping quantization modes to variable names. + def variable_serialization_spec(self): + """Returns a dict mapping quantization modes to variable names in order. This spec is used by `save_own_variables` and `load_own_variables` to - determine which variables should be saved/loaded for each quantization - mode. + determine the correct ordering of variables during serialization for + each quantization mode. `None` means no quantization. """ return { - None: [], - "int8": ["kernel_scale"], - "int4": ["kernel_scale"], + None: [ + "kernel", + "bias", + ], + "int8": [ + "kernel", + "bias", + "kernel_scale", + ], + "int4": [ + "kernel", + "bias", + "kernel_scale", + ], "float8": [ + "kernel", + "bias", "inputs_scale", "inputs_amax_history", "kernel_scale", @@ -399,6 +357,7 @@ def quantization_variable_spec(self): "outputs_grad_amax_history", ], "gptq": [ + "bias", "quantized_kernel", "kernel_scale", "kernel_zero", diff --git a/keras/src/layers/core/einsum_dense.py b/keras/src/layers/core/einsum_dense.py index 2c8f2e2d90d6..37639461d7a9 100644 --- a/keras/src/layers/core/einsum_dense.py +++ b/keras/src/layers/core/einsum_dense.py @@ -326,25 +326,25 @@ def save_own_variables(self, store): if not self.built: return mode = self.quantization_mode - if mode not in self.quantization_variable_spec: + if mode not in self.variable_serialization_spec: raise self._quantization_mode_error(mode) # Kernel plus optional merged LoRA-aware scale (returns (kernel, None) # for None/gptq) kernel_value, merged_kernel_scale = self._get_kernel_with_merged_lora() - - # Save the variables using the name as the key. - if mode != "gptq": - store["kernel"] = kernel_value - if self.bias is not None: - store["bias"] = self.bias - for name in self.quantization_variable_spec[mode]: - if name == "kernel_scale" and mode in ("int4", "int8"): + idx = 0 + for name in self.variable_serialization_spec[mode]: + if name == "kernel": + store[str(idx)] = kernel_value + elif name == "bias" and self.bias is None: + continue + elif name == "kernel_scale" and mode in ("int4", "int8"): # For int4/int8, the merged LoRA scale (if any) comes from # `_get_kernel_with_merged_lora()` - store[name] = merged_kernel_scale + store[str(idx)] = merged_kernel_scale else: - store[name] = getattr(self, name) + store[str(idx)] = getattr(self, name) + idx += 1 def load_own_variables(self, store): if not self.lora_enabled: @@ -353,39 +353,18 @@ def load_own_variables(self, store): if not self.built: return mode = self.quantization_mode - if mode not in self.quantization_variable_spec: + if mode not in self.variable_serialization_spec: raise self._quantization_mode_error(mode) - # Determine whether to use the legacy loading method. - if "0" in store: - return self._legacy_load_own_variables(store) - - # Load the variables using the name as the key. - if mode != "gptq": - self._kernel.assign(store["kernel"]) - if self.bias is not None: - self.bias.assign(store["bias"]) - for name in self.quantization_variable_spec[mode]: - getattr(self, name).assign(store[name]) - if self.lora_enabled: - self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape)) - self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape)) - - def _legacy_load_own_variables(self, store): - # The keys of the `store` will be saved as determined because the - # default ordering will change after quantization - mode = self.quantization_mode - targets = [] - if mode != "gptq": - targets.append(self._kernel) - if self.bias is not None: - targets.append(self.bias) - targets.extend( - getattr(self, name) - for name in self.quantization_variable_spec[mode] - ) - for i, variable in enumerate(targets): - variable.assign(store[str(i)]) + idx = 0 + for name in self.variable_serialization_spec[mode]: + if name == "kernel": + self._kernel.assign(store[str(idx)]) + elif name == "bias" and self.bias is None: + continue + else: + getattr(self, name).assign(store[str(idx)]) + idx += 1 if self.lora_enabled: self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape)) self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape)) @@ -418,53 +397,32 @@ def get_config(self): config["gptq_unpacked_column_size"] = self.gptq_unpacked_column_size return {**base_config, **config} - def _check_load_own_variables(self, store): - all_vars = self._trainable_variables + self._non_trainable_variables - if len(store.keys()) != len(all_vars): - if len(all_vars) == 0 and not self.built: - raise ValueError( - f"Layer '{self.name}' was never built " - "and thus it doesn't have any variables. " - f"However the weights file lists {len(store.keys())} " - "variables for this layer.\n" - "In most cases, this error indicates that either:\n\n" - "1. The layer is owned by a parent layer that " - "implements a `build()` method, but calling the " - "parent's `build()` method did NOT create the state of " - f"the child layer '{self.name}'. A `build()` method " - "must create ALL state for the layer, including " - "the state of any children layers.\n\n" - "2. You need to implement " - "the `def build_from_config(self, config)` method " - f"on layer '{self.name}', to specify how to rebuild " - "it during loading. " - "In this case, you might also want to implement the " - "method that generates the build config at saving time, " - "`def get_build_config(self)`. " - "The method `build_from_config()` is meant " - "to create the state " - "of the layer (i.e. its variables) upon deserialization.", - ) - raise ValueError( - f"Layer '{self.name}' expected {len(all_vars)} variables, " - "but received " - f"{len(store.keys())} variables during loading. " - f"Expected: {[v.name for v in all_vars]}" - ) - @property - def quantization_variable_spec(self): - """Returns a dict mapping quantization modes to variable names. + def variable_serialization_spec(self): + """Returns a dict mapping quantization modes to variable names in order. This spec is used by `save_own_variables` and `load_own_variables` to - determine which variables should be saved/loaded for each quantization - mode. + determine the correct ordering of variables during serialization for + each quantization mode. `None` means no quantization. """ return { - None: [], - "int8": ["kernel_scale"], - "int4": ["kernel_scale"], + None: [ + "kernel", + "bias", + ], + "int8": [ + "kernel", + "bias", + "kernel_scale", + ], + "int4": [ + "kernel", + "bias", + "kernel_scale", + ], "float8": [ + "kernel", + "bias", "inputs_scale", "inputs_amax_history", "kernel_scale", @@ -473,6 +431,7 @@ def quantization_variable_spec(self): "outputs_grad_amax_history", ], "gptq": [ + "bias", "quantized_kernel", "kernel_scale", "kernel_zero", diff --git a/keras/src/layers/core/embedding.py b/keras/src/layers/core/embedding.py index aa809be63f34..c1cb3b6b0117 100644 --- a/keras/src/layers/core/embedding.py +++ b/keras/src/layers/core/embedding.py @@ -218,24 +218,25 @@ def save_own_variables(self, store): if not self.built: return mode = self.quantization_mode - if mode not in self.quantization_variable_spec: + if mode not in self.variable_serialization_spec: raise self._quantization_mode_error(mode) # Embeddings plus optional merged LoRA-aware scale - # (returns (kernel, None) for None/gptq). + # (returns (embeddings, None) for `None` mode). embeddings_value, merged_kernel_scale = ( self._get_embeddings_with_merged_lora() ) - - # Save the variables using the name as the key. - store["embeddings"] = embeddings_value - for name in self.quantization_variable_spec[mode]: - if name == "embeddings_scale" and mode in ("int4", "int8"): + idx = 0 + for name in self.variable_serialization_spec[mode]: + if name == "embeddings": + store[str(idx)] = embeddings_value + elif name == "embeddings_scale" and mode in ("int4", "int8"): # For int4/int8, the merged LoRA scale (if any) comes from # `_get_embeddings_with_merged_lora()` - store[name] = merged_kernel_scale + store[str(idx)] = merged_kernel_scale else: - store[name] = getattr(self, name) + store[str(idx)] = getattr(self, name) + idx += 1 def load_own_variables(self, store): if not self.lora_enabled: @@ -244,36 +245,16 @@ def load_own_variables(self, store): if not self.built: return mode = self.quantization_mode - if mode not in self.quantization_variable_spec: + if mode not in self.variable_serialization_spec: raise self._quantization_mode_error(mode) - # Determine whether to use the legacy loading method. - if "0" in store: - return self._legacy_load_own_variables(store) - - # Load the variables using the name as the key. - self._embeddings.assign(store["embeddings"]) - for name in self.quantization_variable_spec[mode]: - getattr(self, name).assign(store[name]) - if self.lora_enabled: - self.lora_embeddings_a.assign( - ops.zeros(self.lora_embeddings_a.shape) - ) - self.lora_embeddings_b.assign( - ops.zeros(self.lora_embeddings_b.shape) - ) - - def _legacy_load_own_variables(self, store): - # The keys of the `store` will be saved as determined because the - # default ordering will change after quantization - mode = self.quantization_mode - targets = [self._embeddings] - targets.extend( - getattr(self, name) - for name in self.quantization_variable_spec[mode] - ) - for i, variable in enumerate(targets): - variable.assign(store[str(i)]) + idx = 0 + for name in self.variable_serialization_spec[mode]: + if name == "embeddings": + self._embeddings.assign(store[str(idx)]) + else: + getattr(self, name).assign(store[str(idx)]) + idx += 1 if self.lora_enabled: self.lora_embeddings_a.assign( ops.zeros(self.lora_embeddings_a.shape) @@ -306,40 +287,6 @@ def get_config(self): config["lora_alpha"] = self.lora_alpha return {**base_config, **config} - def _check_load_own_variables(self, store): - all_vars = self._trainable_variables + self._non_trainable_variables - if len(store.keys()) != len(all_vars): - if len(all_vars) == 0 and not self.built: - raise ValueError( - f"Layer '{self.name}' was never built " - "and thus it doesn't have any variables. " - f"However the weights file lists {len(store.keys())} " - "variables for this layer.\n" - "In most cases, this error indicates that either:\n\n" - "1. The layer is owned by a parent layer that " - "implements a `build()` method, but calling the " - "parent's `build()` method did NOT create the state of " - f"the child layer '{self.name}'. A `build()` method " - "must create ALL state for the layer, including " - "the state of any children layers.\n\n" - "2. You need to implement " - "the `def build_from_config(self, config)` method " - f"on layer '{self.name}', to specify how to rebuild " - "it during loading. " - "In this case, you might also want to implement the " - "method that generates the build config at saving time, " - "`def get_build_config(self)`. " - "The method `build_from_config()` is meant " - "to create the state " - "of the layer (i.e. its variables) upon deserialization.", - ) - raise ValueError( - f"Layer '{self.name}' expected {len(all_vars)} variables, " - "but received " - f"{len(store.keys())} variables during loading. " - f"Expected: {[v.name for v in all_vars]}" - ) - def _quantization_mode_error(self, mode): return NotImplementedError( "Invalid quantization mode. Expected one of ('int8', 'int4'). " @@ -347,17 +294,25 @@ def _quantization_mode_error(self, mode): ) @property - def quantization_variable_spec(self): - """Returns a dict mapping quantization modes to variable names. + def variable_serialization_spec(self): + """Returns a dict mapping quantization modes to variable names in order. This spec is used by `save_own_variables` and `load_own_variables` to - determine which variables should be saved/loaded for each quantization - mode. + determine the correct ordering of variables during serialization for + each quantization mode. `None` means no quantization. """ return { - None: [], - "int8": ["embeddings_scale"], - "int4": ["embeddings_scale"], + None: [ + "embeddings", + ], + "int8": [ + "embeddings", + "embeddings_scale", + ], + "int4": [ + "embeddings", + "embeddings_scale", + ], } def quantized_build(self, embeddings_shape, mode): diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index 11e4046c7b8a..9e6c928e3ee4 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -1368,15 +1368,7 @@ def save_own_variables(self, store): for i, v in enumerate(all_vars): store[f"{i}"] = v - def load_own_variables(self, store): - """Loads the state of the layer. - - You can override this method to take full control of how the state of - the layer is loaded upon calling `keras.models.load_model()`. - - Args: - store: Dict from which the state of the model will be loaded. - """ + def _check_load_own_variables(self, store): all_vars = self._trainable_variables + self._non_trainable_variables if len(store.keys()) != len(all_vars): if len(all_vars) == 0 and not self.built: @@ -1409,6 +1401,18 @@ def load_own_variables(self, store): f"{len(store.keys())} variables during loading. " f"Expected: {[v.name for v in all_vars]}" ) + + def load_own_variables(self, store): + """Loads the state of the layer. + + You can override this method to take full control of how the state of + the layer is loaded upon calling `keras.models.load_model()`. + + Args: + store: Dict from which the state of the model will be loaded. + """ + self._check_load_own_variables(store) + all_vars = self._trainable_variables + self._non_trainable_variables for i, v in enumerate(all_vars): v.assign(store[f"{i}"]) diff --git a/keras/src/saving/file_editor_test.py b/keras/src/saving/file_editor_test.py index f02ca11516b1..965c97ba863d 100644 --- a/keras/src/saving/file_editor_test.py +++ b/keras/src/saving/file_editor_test.py @@ -42,7 +42,7 @@ def test_basics(self): out = editor.compare(target_model) # Fails editor.add_object( - "layers/dense_3", weights={"kernel": np.random.random((3, 3))} + "layers/dense_3", weights={"0": np.random.random((3, 3))} ) out = editor.compare(target_model) # Fails self.assertEqual(out["status"], "error") @@ -50,7 +50,7 @@ def test_basics(self): editor.rename_object("dense_3", "dense_4") editor.rename_object("layers/dense_4", "dense_2") - editor.add_weights("dense_2", weights={"bias": np.random.random((3,))}) + editor.add_weights("dense_2", weights={"1": np.random.random((3,))}) out = editor.compare(target_model) # Succeeds self.assertEqual(out["status"], "success") @@ -75,18 +75,18 @@ def test_basics(self): out = editor.compare(target_model) # Succeeds self.assertEqual(out["status"], "success") - editor.delete_weight("dense_2", "bias") + editor.delete_weight("dense_2", "1") out = editor.compare(target_model) # Fails self.assertEqual(out["status"], "error") self.assertEqual(out["error_count"], 1) - editor.add_weights("dense_2", {"bias": np.zeros((7,))}) + editor.add_weights("dense_2", {"1": np.zeros((7,))}) out = editor.compare(target_model) # Fails self.assertEqual(out["status"], "error") self.assertEqual(out["error_count"], 1) - editor.delete_weight("dense_2", "bias") - editor.add_weights("dense_2", {"bias": np.zeros((3,))}) + editor.delete_weight("dense_2", "1") + editor.add_weights("dense_2", {"1": np.zeros((3,))}) out = editor.compare(target_model) # Succeeds self.assertEqual(out["status"], "success")