125 changes: 42 additions & 83 deletions keras/src/layers/core/dense.py
@@ -258,25 +258,25 @@ def save_own_variables(self, store):
         if not self.built:
             return
         mode = self.quantization_mode
-        if mode not in self.quantization_variable_spec:
+        if mode not in self.variable_serialization_spec:
             raise self._quantization_mode_error(mode)
 
         # Kernel plus optional merged LoRA-aware scale (returns (kernel, None)
         # for None/gptq)
         kernel_value, merged_kernel_scale = self._get_kernel_with_merged_lora()
 
-        # Save the variables using the name as the key.
-        if mode != "gptq":
-            store["kernel"] = kernel_value
-            if self.bias is not None:
-                store["bias"] = self.bias
-        for name in self.quantization_variable_spec[mode]:
-            if name == "kernel_scale" and mode in ("int4", "int8"):
+        idx = 0
+        for name in self.variable_serialization_spec[mode]:
+            if name == "kernel":
+                store[str(idx)] = kernel_value
+            elif name == "bias" and self.bias is None:
+                continue
+            elif name == "kernel_scale" and mode in ("int4", "int8"):
                 # For int4/int8, the merged LoRA scale (if any) comes from
                 # `_get_kernel_with_merged_lora()`
-                store[name] = merged_kernel_scale
+                store[str(idx)] = merged_kernel_scale
             else:
-                store[name] = getattr(self, name)
+                store[str(idx)] = getattr(self, name)
+            idx += 1
 
     def load_own_variables(self, store):
         if not self.lora_enabled:
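The saving loop above assigns contiguous string indices in spec order, and a missing bias is skipped without consuming an index. A minimal standalone sketch of the key scheme (a plain dict stands in for `store`, the values are hypothetical placeholders, and the skip is generalized; the real loop special-cases only `bias` and the merged `kernel_scale`):

```python
# Standalone sketch of the ordered-key scheme (hypothetical values; not the
# actual layer code).
def ordered_store(names, values):
    """Assign contiguous string indices to the variables that exist."""
    store, idx = {}, 0
    for name in names:
        if values.get(name) is None:  # e.g. a layer built without a bias
            continue
        store[str(idx)] = values[name]
        idx += 1
    return store

# int8 spec order is ["kernel", "bias", "kernel_scale"]:
print(ordered_store(
    ["kernel", "bias", "kernel_scale"],
    {"kernel": "<kernel>", "bias": None, "kernel_scale": "<scale>"},
))
# {'0': '<kernel>', '1': '<scale>'} -- no bias, indices stay contiguous
```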
@@ -285,39 +285,18 @@ def load_own_variables(self, store):
         if not self.built:
             return
         mode = self.quantization_mode
-        if mode not in self.quantization_variable_spec:
+        if mode not in self.variable_serialization_spec:
             raise self._quantization_mode_error(mode)
 
-        # Determine whether to use the legacy loading method.
-        if "0" in store:
-            return self._legacy_load_own_variables(store)
-
-        # Load the variables using the name as the key.
-        if mode != "gptq":
-            self._kernel.assign(store["kernel"])
-            if self.bias is not None:
-                self.bias.assign(store["bias"])
-        for name in self.quantization_variable_spec[mode]:
-            getattr(self, name).assign(store[name])
-        if self.lora_enabled:
-            self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
-            self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
-
-    def _legacy_load_own_variables(self, store):
-        # The keys of the `store` will be saved as determined because the
-        # default ordering will change after quantization
-        mode = self.quantization_mode
-        targets = []
-        if mode != "gptq":
-            targets.append(self._kernel)
-            if self.bias is not None:
-                targets.append(self.bias)
-        targets.extend(
-            getattr(self, name)
-            for name in self.quantization_variable_spec[mode]
-        )
-        for i, variable in enumerate(targets):
-            variable.assign(store[str(i)])
+        idx = 0
+        for name in self.variable_serialization_spec[mode]:
+            if name == "kernel":
+                self._kernel.assign(store[str(idx)])
+            elif name == "bias" and self.bias is None:
+                continue
+            else:
+                getattr(self, name).assign(store[str(idx)])
+            idx += 1
         if self.lora_enabled:
             self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
             self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
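The trailing adapter reset pairs with the save path: `_get_kernel_with_merged_lora()` folds the LoRA delta into the kernel that gets persisted, so after loading, the adapters must restart from zero or the delta would be applied twice. A schematic in standard LoRA algebra (hypothetical shapes; the actual method also handles quantized kernels and the `lora_alpha / lora_rank` scaling):

```python
# Why the LoRA adapters are zeroed after loading (schematic, not layer code).
import numpy as np

fan_in, fan_out, rank = 4, 3, 2
kernel = np.random.randn(fan_in, fan_out)
lora_a = np.random.randn(fan_in, rank)
lora_b = np.random.randn(rank, fan_out)

merged = kernel + lora_a @ lora_b  # what save_own_variables persists
a0 = np.zeros((fan_in, rank))      # adapters after load_own_variables
b0 = np.zeros((rank, fan_out))
restored = merged + a0 @ b0        # effective kernel after reload
np.testing.assert_allclose(restored, merged)
```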
@@ -344,53 +323,32 @@ def get_config(self):
             config["lora_alpha"] = self.lora_alpha
         return {**base_config, **config}
 
-    def _check_load_own_variables(self, store):
-        all_vars = self._trainable_variables + self._non_trainable_variables
-        if len(store.keys()) != len(all_vars):
-            if len(all_vars) == 0 and not self.built:
-                raise ValueError(
-                    f"Layer '{self.name}' was never built "
-                    "and thus it doesn't have any variables. "
-                    f"However the weights file lists {len(store.keys())} "
-                    "variables for this layer.\n"
-                    "In most cases, this error indicates that either:\n\n"
-                    "1. The layer is owned by a parent layer that "
-                    "implements a `build()` method, but calling the "
-                    "parent's `build()` method did NOT create the state of "
-                    f"the child layer '{self.name}'. A `build()` method "
-                    "must create ALL state for the layer, including "
-                    "the state of any children layers.\n\n"
-                    "2. You need to implement "
-                    "the `def build_from_config(self, config)` method "
-                    f"on layer '{self.name}', to specify how to rebuild "
-                    "it during loading. "
-                    "In this case, you might also want to implement the "
-                    "method that generates the build config at saving time, "
-                    "`def get_build_config(self)`. "
-                    "The method `build_from_config()` is meant "
-                    "to create the state "
-                    "of the layer (i.e. its variables) upon deserialization.",
-                )
-            raise ValueError(
-                f"Layer '{self.name}' expected {len(all_vars)} variables, "
-                "but received "
-                f"{len(store.keys())} variables during loading. "
-                f"Expected: {[v.name for v in all_vars]}"
-            )
-
     @property
-    def quantization_variable_spec(self):
-        """Returns a dict mapping quantization modes to variable names.
+    def variable_serialization_spec(self):
+        """Returns a dict mapping quantization modes to variable names in order.
 
         This spec is used by `save_own_variables` and `load_own_variables` to
-        determine which variables should be saved/loaded for each quantization
-        mode.
+        determine the correct ordering of variables during serialization for
+        each quantization mode. `None` means no quantization.
         """
         return {
-            None: [],
-            "int8": ["kernel_scale"],
-            "int4": ["kernel_scale"],
+            None: [
+                "kernel",
+                "bias",
+            ],
+            "int8": [
+                "kernel",
+                "bias",
+                "kernel_scale",
+            ],
+            "int4": [
+                "kernel",
+                "bias",
+                "kernel_scale",
+            ],
             "float8": [
+                "kernel",
+                "bias",
                 "inputs_scale",
                 "inputs_amax_history",
                 "kernel_scale",
@@ -399,6 +357,7 @@ def quantization_variable_spec(self):
                 "outputs_grad_amax_history",
             ],
             "gptq": [
+                "bias",
                 "quantized_kernel",
                 "kernel_scale",
                 "kernel_zero",
125 changes: 42 additions & 83 deletions keras/src/layers/core/einsum_dense.py
@@ -326,25 +326,25 @@ def save_own_variables(self, store):
         if not self.built:
             return
         mode = self.quantization_mode
-        if mode not in self.quantization_variable_spec:
+        if mode not in self.variable_serialization_spec:
             raise self._quantization_mode_error(mode)
 
         # Kernel plus optional merged LoRA-aware scale (returns (kernel, None)
         # for None/gptq)
         kernel_value, merged_kernel_scale = self._get_kernel_with_merged_lora()
 
-        # Save the variables using the name as the key.
-        if mode != "gptq":
-            store["kernel"] = kernel_value
-            if self.bias is not None:
-                store["bias"] = self.bias
-        for name in self.quantization_variable_spec[mode]:
-            if name == "kernel_scale" and mode in ("int4", "int8"):
+        idx = 0
+        for name in self.variable_serialization_spec[mode]:
+            if name == "kernel":
+                store[str(idx)] = kernel_value
+            elif name == "bias" and self.bias is None:
+                continue
+            elif name == "kernel_scale" and mode in ("int4", "int8"):
                 # For int4/int8, the merged LoRA scale (if any) comes from
                 # `_get_kernel_with_merged_lora()`
-                store[name] = merged_kernel_scale
+                store[str(idx)] = merged_kernel_scale
             else:
-                store[name] = getattr(self, name)
+                store[str(idx)] = getattr(self, name)
+            idx += 1
 
     def load_own_variables(self, store):
         if not self.lora_enabled:
@@ -353,39 +353,18 @@ def load_own_variables(self, store):
         if not self.built:
             return
         mode = self.quantization_mode
-        if mode not in self.quantization_variable_spec:
+        if mode not in self.variable_serialization_spec:
             raise self._quantization_mode_error(mode)
 
-        # Determine whether to use the legacy loading method.
-        if "0" in store:
-            return self._legacy_load_own_variables(store)
-
-        # Load the variables using the name as the key.
-        if mode != "gptq":
-            self._kernel.assign(store["kernel"])
-            if self.bias is not None:
-                self.bias.assign(store["bias"])
-        for name in self.quantization_variable_spec[mode]:
-            getattr(self, name).assign(store[name])
-        if self.lora_enabled:
-            self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
-            self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
-
-    def _legacy_load_own_variables(self, store):
-        # The keys of the `store` will be saved as determined because the
-        # default ordering will change after quantization
-        mode = self.quantization_mode
-        targets = []
-        if mode != "gptq":
-            targets.append(self._kernel)
-            if self.bias is not None:
-                targets.append(self.bias)
-        targets.extend(
-            getattr(self, name)
-            for name in self.quantization_variable_spec[mode]
-        )
-        for i, variable in enumerate(targets):
-            variable.assign(store[str(i)])
+        idx = 0
+        for name in self.variable_serialization_spec[mode]:
+            if name == "kernel":
+                self._kernel.assign(store[str(idx)])
+            elif name == "bias" and self.bias is None:
+                continue
+            else:
+                getattr(self, name).assign(store[str(idx)])
+            idx += 1
         if self.lora_enabled:
             self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
             self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
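As a usage-level check, the new scheme round-trips through a plain dict. A sketch assuming this change is applied, with an unquantized `EinsumDense` (mode `None`, spec order `["kernel", "bias"]`):

```python
# Round-trip sketch (assumes the numeric-key scheme from this PR).
import numpy as np
import keras

layer = keras.layers.EinsumDense("ab,bc->ac", output_shape=8, bias_axes="c")
layer.build((None, 4))

store = {}
layer.save_own_variables(store)
assert sorted(store) == ["0", "1"]  # "0" -> kernel, "1" -> bias

clone = keras.layers.EinsumDense("ab,bc->ac", output_shape=8, bias_axes="c")
clone.build((None, 4))
clone.load_own_variables(store)
np.testing.assert_allclose(
    keras.ops.convert_to_numpy(clone.kernel),
    keras.ops.convert_to_numpy(layer.kernel),
)
```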
@@ -418,53 +397,32 @@ def get_config(self):
             config["gptq_unpacked_column_size"] = self.gptq_unpacked_column_size
         return {**base_config, **config}
 
-    def _check_load_own_variables(self, store):
-        all_vars = self._trainable_variables + self._non_trainable_variables
-        if len(store.keys()) != len(all_vars):
-            if len(all_vars) == 0 and not self.built:
-                raise ValueError(
-                    f"Layer '{self.name}' was never built "
-                    "and thus it doesn't have any variables. "
-                    f"However the weights file lists {len(store.keys())} "
-                    "variables for this layer.\n"
-                    "In most cases, this error indicates that either:\n\n"
-                    "1. The layer is owned by a parent layer that "
-                    "implements a `build()` method, but calling the "
-                    "parent's `build()` method did NOT create the state of "
-                    f"the child layer '{self.name}'. A `build()` method "
-                    "must create ALL state for the layer, including "
-                    "the state of any children layers.\n\n"
-                    "2. You need to implement "
-                    "the `def build_from_config(self, config)` method "
-                    f"on layer '{self.name}', to specify how to rebuild "
-                    "it during loading. "
-                    "In this case, you might also want to implement the "
-                    "method that generates the build config at saving time, "
-                    "`def get_build_config(self)`. "
-                    "The method `build_from_config()` is meant "
-                    "to create the state "
-                    "of the layer (i.e. its variables) upon deserialization.",
-                )
-            raise ValueError(
-                f"Layer '{self.name}' expected {len(all_vars)} variables, "
-                "but received "
-                f"{len(store.keys())} variables during loading. "
-                f"Expected: {[v.name for v in all_vars]}"
-            )
-
     @property
-    def quantization_variable_spec(self):
-        """Returns a dict mapping quantization modes to variable names.
+    def variable_serialization_spec(self):
+        """Returns a dict mapping quantization modes to variable names in order.
 
         This spec is used by `save_own_variables` and `load_own_variables` to
-        determine which variables should be saved/loaded for each quantization
-        mode.
+        determine the correct ordering of variables during serialization for
+        each quantization mode. `None` means no quantization.
        """
         return {
-            None: [],
-            "int8": ["kernel_scale"],
-            "int4": ["kernel_scale"],
+            None: [
+                "kernel",
+                "bias",
+            ],
+            "int8": [
+                "kernel",
+                "bias",
+                "kernel_scale",
+            ],
+            "int4": [
+                "kernel",
+                "bias",
+                "kernel_scale",
+            ],
             "float8": [
+                "kernel",
+                "bias",
                 "inputs_scale",
                 "inputs_amax_history",
                 "kernel_scale",
@@ -473,6 +431,7 @@ def quantization_variable_spec(self):
                 "outputs_grad_amax_history",
             ],
             "gptq": [
+                "bias",
                 "quantized_kernel",
                 "kernel_scale",
                 "kernel_zero",