Skip to content

fix quantization save and load error #21504

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions keras/src/layers/core/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,12 +345,14 @@ def _int8_build(self, kernel_shape):
dtype="int8",
trainable=False,
)
self._kernel._is_quantized = True
self.kernel_scale = self.add_weight(
name="kernel_scale",
shape=(self.units,),
initializer="ones",
trainable=False,
)
self.kernel_scale._is_quantized = True

def _int4_build(self, kernel_shape):
"""Build variables for int4 quantization.
Expand All @@ -375,13 +377,15 @@ def _int4_build(self, kernel_shape):
dtype="int8",
trainable=False,
)
self._kernel._is_quantized = True
# One scale per output unit (per-channel).
self.kernel_scale = self.add_weight(
name="kernel_scale",
shape=(self.units,),
initializer="ones",
trainable=False,
)
self.kernel_scale._is_quantized = True
# Record original input_dim for unpacking at runtime.
self._orig_input_dim = input_dim

Expand Down Expand Up @@ -414,19 +418,25 @@ def _float8_build(self):
"overwrite_with_gradient": True,
}
self.inputs_scale = self.add_weight(name="inputs_scale", **scale_kwargs)
self.inputs_scale._is_quantized = True
self.inputs_amax_history = self.add_weight(
name="inputs_amax_history", **amax_history_kwargs
)
self.inputs_amax_history._is_quantized = True
self.kernel_scale = self.add_weight(name="kernel_scale", **scale_kwargs)
self.kernel_scale._is_quantized = True
self.kernel_amax_history = self.add_weight(
name="kernel_amax_history", **amax_history_kwargs
)
self.kernel_amax_history._is_quantized = True
self.outputs_grad_scale = self.add_weight(
name="outputs_grad_scale", **scale_kwargs
)
self.outputs_grad_scale._is_quantized = True
self.outputs_grad_amax_history = self.add_weight(
name="outputs_grad_amax_history", **amax_history_kwargs
)
self.outputs_grad_amax_history._is_quantized = True

def _int8_call(self, inputs, training=None):
@ops.custom_gradient
Expand Down
10 changes: 10 additions & 0 deletions keras/src/layers/core/einsum_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,13 +412,15 @@ def _int8_build(self, kernel_shape):
dtype="int8",
trainable=False,
)
self._kernel._is_quantized = True
kernel_scale_shape = self._get_kernel_scale_shape(kernel_shape)
self.kernel_scale = self.add_weight(
name="kernel_scale",
shape=kernel_scale_shape,
initializer="ones",
trainable=False,
)
self.kernel_scale._is_quantized = True

def _int4_build(self, kernel_shape):
"""Build variables for int4 quantization.
Expand Down Expand Up @@ -461,6 +463,7 @@ def _int4_build(self, kernel_shape):
dtype="int8",
trainable=False,
)
self._kernel._is_quantized = True

# Kernel scale
kernel_scale_shape = self._get_kernel_scale_shape(kernel_shape)
Expand All @@ -470,6 +473,7 @@ def _int4_build(self, kernel_shape):
initializer="ones",
trainable=False,
)
self.kernel_scale._is_quantized = True

def _float8_build(self):
from keras.src.dtype_policies import QuantizedFloat8DTypePolicy
Expand Down Expand Up @@ -500,19 +504,25 @@ def _float8_build(self):
"overwrite_with_gradient": True,
}
self.inputs_scale = self.add_weight(name="inputs_scale", **scale_kwargs)
self.inputs_scale._is_quantized = True
self.inputs_amax_history = self.add_weight(
name="inputs_amax_history", **amax_history_kwargs
)
self.inputs_amax_history._is_quantized = True
self.kernel_scale = self.add_weight(name="kernel_scale", **scale_kwargs)
self.kernel_scale._is_quantized = True
self.kernel_amax_history = self.add_weight(
name="kernel_amax_history", **amax_history_kwargs
)
self.kernel_amax_history._is_quantized = True
self.outputs_grad_scale = self.add_weight(
name="outputs_grad_scale", **scale_kwargs
)
self.outputs_grad_scale._is_quantized = True
self.outputs_grad_amax_history = self.add_weight(
name="outputs_grad_amax_history", **amax_history_kwargs
)
self.outputs_grad_amax_history._is_quantized = True

def _int8_call(self, inputs, training=None):
@ops.custom_gradient
Expand Down
2 changes: 2 additions & 0 deletions keras/src/layers/core/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,7 @@ def _int8_build(self, embeddings_shape):
dtype="int8",
trainable=False,
)
self._embeddings._is_quantized = True
# We choose to reduce the axis of `output_dim` because, typically,
# `input_dim` is larger than `output_dim`. This reduces quantization
# error.
Expand All @@ -337,6 +338,7 @@ def _int8_build(self, embeddings_shape):
initializer="ones",
trainable=False,
)
self.embeddings_scale._is_quantized = True

def quantized_call(self, *args, **kwargs):
if self.quantization_mode != "int8":
Expand Down
43 changes: 42 additions & 1 deletion keras/src/layers/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1359,10 +1359,33 @@ def save_own_variables(self, store):
Args:
store: Dict where the state of the model will be saved.
"""
if not getattr(self, "_is_quantized", False):
all_vars = self._trainable_variables + self._non_trainable_variables
for i, v in enumerate(all_vars):
store[f"{i}"] = v
return

# Case: quantized layer
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you invert the if? If getattr(self, "_is_quantized", False): ... (more readable)

quantized_vars = self._get_quantized_variables()
for i, v in enumerate(quantized_vars):
store[f"quantized_{i}"] = v

# Save non-quantized variables
all_vars = self._trainable_variables + self._non_trainable_variables
for i, v in enumerate(all_vars):
quantized_vars_set = set(quantized_vars)
non_quantized_vars = [
v for v in all_vars if v not in quantized_vars_set
]
for i, v in enumerate(non_quantized_vars):
store[f"{i}"] = v

def _get_quantized_variables(self):
quantized_vars = []
for v in self._trainable_variables + self._non_trainable_variables:
if getattr(v, "_is_quantized", False):
quantized_vars.append(v)
return quantized_vars

def load_own_variables(self, store):
"""Loads the state of the layer.

Expand All @@ -1372,6 +1395,10 @@ def load_own_variables(self, store):
Args:
store: Dict from which the state of the model will be loaded.
"""
if any(key.startswith("quantized_") for key in store.keys()):
self._load_quantized_variables(store)
return

all_vars = self._trainable_variables + self._non_trainable_variables
if len(store.keys()) != len(all_vars):
if len(all_vars) == 0 and not self.built:
Expand Down Expand Up @@ -1407,6 +1434,20 @@ def load_own_variables(self, store):
for i, v in enumerate(all_vars):
v.assign(store[f"{i}"])

def _load_quantized_variables(self, store):
quantized_vars = self._get_quantized_variables()
for i, v in enumerate(quantized_vars):
v.assign(store[f"quantized_{i}"])

# Load non-quantized variables
all_vars = self._trainable_variables + self._non_trainable_variables
quantized_vars_set = set(quantized_vars)
non_quantized_vars = [
v for v in all_vars if v not in quantized_vars_set
]
for i, v in enumerate(non_quantized_vars):
v.assign(store[f"{i}"])

def _track_variable(self, variable):
if variable.trainable:
self._tracker.add_to_store("trainable_variables", variable)
Expand Down
15 changes: 15 additions & 0 deletions keras/src/layers/layer_test.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import pickle
from unittest import mock

Expand All @@ -12,6 +13,7 @@
from keras.src import metrics
from keras.src import models
from keras.src import ops
from keras.src import saving
from keras.src import testing
from keras.src.backend.common import global_state
from keras.src.backend.common.remat import RematScope
Expand Down Expand Up @@ -1758,3 +1760,16 @@ def call(self, x):
# foo_mode omitted -> foo_mode defaults to False -> no change
y2 = model(sample_input)
self.assertAllClose(y2, sample_input)

def test_quantized_model_save_and_load(self):
inputs = layers.Input(shape=(None,))
x = layers.Embedding(input_dim=10, output_dim=10)(inputs)
x = layers.Dense(10)(x)
model = models.Model(inputs=inputs, outputs=x)
path = os.path.join(self.get_temp_dir(), "quantized_model.keras")
model.quantize(mode="int8")
model.save(path)

quantized_model = saving.load_model(path)

self.assertTrue(quantized_model.built)