Commit 49797f2

Refactor save_own_variables and load_own_variables to use the names as the keys. (#21681)
* Refactor `save_own_variables` and `load_own_variables` to use the names as the keys.
* Add int4 quantization to `Embedding` and refactor `save_own_variables` and `load_own_variables` of `EinsumDense` and `Embedding`.
* Loosen error threshold.
* Rename `MODE_SPEC` to `quantization_variable_spec` and fix `Embedding.embeddings`.
* Fix tests for the kernel changes.
1 parent 0102777 commit 49797f2
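
For context, a minimal sketch of why name keys are more robust than the old positional keys once quantization changes the variable set. The plain dicts below are stand-ins for the real weight store, and the string values are placeholders, not actual tensors:

```python
# Legacy layout: variables were stored under positional keys, so what an
# index refers to depends on which quantization variables exist.
legacy_float_store = {"0": "kernel", "1": "bias"}
legacy_int8_store = {"0": "kernel", "1": "bias", "2": "kernel_scale"}

# New layout (this commit): every variable is addressed by name, so adding
# or removing quantization variables cannot shift the other entries.
named_int8_store = {
    "kernel": "quantized kernel values",
    "bias": "bias values",
    "kernel_scale": "per-channel scale",
}
```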

File tree

7 files changed: +716 -502 lines changed


keras/src/layers/core/dense.py

Lines changed: 62 additions & 63 deletions
@@ -235,94 +235,65 @@ def save_own_variables(self, store):
         # Do nothing if the layer isn't yet built
         if not self.built:
             return
-        # The keys of the `store` will be saved as determined because the
-        # default ordering will change after quantization
         mode = self.quantization_mode
-
-        # For int4/int8, the merged LoRA scale (if any) comes from
-        # `_get_kernel_with_merged_lora()` and is appended below.
-        MODE_SPEC = {
-            None: [],
-            "int4": [],
-            "int8": [],
-            "float8": [
-                "inputs_scale",
-                "inputs_amax_history",
-                "kernel_scale",
-                "kernel_amax_history",
-                "outputs_grad_scale",
-                "outputs_grad_amax_history",
-            ],
-            "gptq": [
-                "quantized_kernel",
-                "kernel_scale",
-                "kernel_zero",
-                "g_idx",
-            ],
-        }
-
-        if mode not in MODE_SPEC:
+        if mode not in self.quantization_variable_spec:
             raise self._quantization_mode_error(mode)
 
         # Kernel plus optional merged LoRA-aware scale (returns (kernel, None)
         # for None/gptq)
         kernel_value, merged_kernel_scale = self._get_kernel_with_merged_lora()
 
-        targets = []
+        # Save the variables using the name as the key.
         if mode != "gptq":
-            targets.append(kernel_value)
+            store["kernel"] = kernel_value
         if self.bias is not None:
-            targets.append(self.bias)
-        if merged_kernel_scale is not None and mode in ("int4", "int8"):
-            targets.append(merged_kernel_scale)
-
-        # Append per-mode attributes (order matters)
-        targets.extend(getattr(self, name) for name in MODE_SPEC[mode])
-
-        for i, var in enumerate(targets):
-            store[str(i)] = var
+            store["bias"] = self.bias
+        for name in self.quantization_variable_spec[mode]:
+            if name == "kernel_scale" and mode in ("int4", "int8"):
+                # For int4/int8, the merged LoRA scale (if any) comes from
+                # `_get_kernel_with_merged_lora()`
+                store[name] = merged_kernel_scale
+            else:
+                store[name] = getattr(self, name)
 
     def load_own_variables(self, store):
         if not self.lora_enabled:
             self._check_load_own_variables(store)
         # Do nothing if the layer isn't yet built
         if not self.built:
             return
-        # The keys of the `store` will be saved as determined because the
-        # default ordering will change after quantization
         mode = self.quantization_mode
+        if mode not in self.quantization_variable_spec:
+            raise self._quantization_mode_error(mode)
 
-        # Per-mode variable spec (order matters).
-        MODE_SPEC = {
-            None: [],
-            "int8": ["kernel_scale"],
-            "int4": ["kernel_scale"],
-            "float8": [
-                "inputs_scale",
-                "inputs_amax_history",
-                "kernel_scale",
-                "kernel_amax_history",
-                "outputs_grad_scale",
-                "outputs_grad_amax_history",
-            ],
-            "gptq": [
-                "quantized_kernel",
-                "kernel_scale",
-                "kernel_zero",
-                "g_idx",
-            ],
-        }
+        # Determine whether to use the legacy loading method.
+        if "0" in store:
+            return self._legacy_load_own_variables(store)
 
-        if mode not in MODE_SPEC:
-            raise self._quantization_mode_error(mode)
+        # Load the variables using the name as the key.
+        if mode != "gptq":
+            self._kernel.assign(store["kernel"])
+        if self.bias is not None:
+            self.bias.assign(store["bias"])
+        for name in self.quantization_variable_spec[mode]:
+            getattr(self, name).assign(store[name])
+        if self.lora_enabled:
+            self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
+            self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
 
+    def _legacy_load_own_variables(self, store):
+        # The keys of the `store` will be saved as determined because the
+        # default ordering will change after quantization
+        mode = self.quantization_mode
         targets = []
         if mode != "gptq":
             targets.append(self._kernel)
         if self.bias is not None:
             targets.append(self.bias)
-        targets.extend(getattr(self, name) for name in MODE_SPEC[mode])
-
+        targets.extend(
+            getattr(self, name)
+            for name in self.quantization_variable_spec[mode]
+        )
         for i, variable in enumerate(targets):
             variable.assign(store[str(i)])
         if self.lora_enabled:

@@ -385,6 +356,34 @@ def _check_load_own_variables(self, store):
                 f"Expected: {[v.name for v in all_vars]}"
             )
 
+    @property
+    def quantization_variable_spec(self):
+        """Returns a dict mapping quantization modes to variable names.
+
+        This spec is used by `save_own_variables` and `load_own_variables` to
+        determine which variables should be saved/loaded for each quantization
+        mode.
+        """
+        return {
+            None: [],
+            "int8": ["kernel_scale"],
+            "int4": ["kernel_scale"],
+            "float8": [
+                "inputs_scale",
+                "inputs_amax_history",
+                "kernel_scale",
+                "kernel_amax_history",
+                "outputs_grad_scale",
+                "outputs_grad_amax_history",
+            ],
+            "gptq": [
+                "quantized_kernel",
+                "kernel_scale",
+                "kernel_zero",
+                "g_idx",
+            ],
+        }
+
     def quantized_build(self, kernel_shape, mode, config=None):
         if mode == "int8":
             self._int8_build(kernel_shape)
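
As a rough usage sketch of the name-keyed behavior above (assuming a Keras 3 build with int8 quantization support; the key set follows `quantization_variable_spec` for int8 and the plain dict stands in for a real weight store):

```python
import keras

# Build and quantize a small Dense layer, then save its variables.
layer = keras.layers.Dense(4)
layer.build((None, 8))
layer.quantize("int8")

store = {}
layer.save_own_variables(store)
# With int8, the expected keys are "kernel", "bias", and "kernel_scale".
print(sorted(store))

# Restore into a fresh layer built and quantized the same way. A legacy
# checkpoint with positional keys ("0", "1", ...) would instead hit the
# `if "0" in store` fallback to `_legacy_load_own_variables` shown above.
restored = keras.layers.Dense(4)
restored.build((None, 8))
restored.quantize("int8")
restored.load_own_variables(store)
```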
