
Commit d52caae

Refactor variable serialization.
1 parent 3fac66f

File tree

5 files changed: +136 -259 lines changed

keras/src/layers/core/dense.py

Lines changed: 42 additions & 83 deletions
@@ -258,25 +258,25 @@ def save_own_variables(self, store):
         if not self.built:
             return
         mode = self.quantization_mode
-        if mode not in self.quantization_variable_spec:
+        if mode not in self.variable_serialization_spec:
             raise self._quantization_mode_error(mode)
 
         # Kernel plus optional merged LoRA-aware scale (returns (kernel, None)
         # for None/gptq)
         kernel_value, merged_kernel_scale = self._get_kernel_with_merged_lora()
-
-        # Save the variables using the name as the key.
-        if mode != "gptq":
-            store["kernel"] = kernel_value
-        if self.bias is not None:
-            store["bias"] = self.bias
-        for name in self.quantization_variable_spec[mode]:
-            if name == "kernel_scale" and mode in ("int4", "int8"):
+        idx = 0
+        for name in self.variable_serialization_spec[mode]:
+            if name == "kernel":
+                store[str(idx)] = kernel_value
+            elif name == "bias" and self.bias is None:
+                continue
+            elif name == "kernel_scale" and mode in ("int4", "int8"):
                 # For int4/int8, the merged LoRA scale (if any) comes from
                 # `_get_kernel_with_merged_lora()`
-                store[name] = merged_kernel_scale
+                store[str(idx)] = merged_kernel_scale
             else:
-                store[name] = getattr(self, name)
+                store[str(idx)] = getattr(self, name)
+            idx += 1
 
     def load_own_variables(self, store):
         if not self.lora_enabled:
@@ -285,39 +285,18 @@ def load_own_variables(self, store):
         if not self.built:
             return
         mode = self.quantization_mode
-        if mode not in self.quantization_variable_spec:
+        if mode not in self.variable_serialization_spec:
             raise self._quantization_mode_error(mode)
 
-        # Determine whether to use the legacy loading method.
-        if "0" in store:
-            return self._legacy_load_own_variables(store)
-
-        # Load the variables using the name as the key.
-        if mode != "gptq":
-            self._kernel.assign(store["kernel"])
-        if self.bias is not None:
-            self.bias.assign(store["bias"])
-        for name in self.quantization_variable_spec[mode]:
-            getattr(self, name).assign(store[name])
-        if self.lora_enabled:
-            self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
-            self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
-
-    def _legacy_load_own_variables(self, store):
-        # The keys of the `store` will be saved as determined because the
-        # default ordering will change after quantization
-        mode = self.quantization_mode
-        targets = []
-        if mode != "gptq":
-            targets.append(self._kernel)
-        if self.bias is not None:
-            targets.append(self.bias)
-        targets.extend(
-            getattr(self, name)
-            for name in self.quantization_variable_spec[mode]
-        )
-        for i, variable in enumerate(targets):
-            variable.assign(store[str(i)])
+        idx = 0
+        for name in self.variable_serialization_spec[mode]:
+            if name == "kernel":
+                self._kernel.assign(store[str(idx)])
+            elif name == "bias" and self.bias is None:
+                continue
+            else:
+                getattr(self, name).assign(store[str(idx)])
+            idx += 1
         if self.lora_enabled:
             self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
             self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
@@ -344,53 +323,32 @@ def get_config(self):
         config["lora_alpha"] = self.lora_alpha
         return {**base_config, **config}
 
-    def _check_load_own_variables(self, store):
-        all_vars = self._trainable_variables + self._non_trainable_variables
-        if len(store.keys()) != len(all_vars):
-            if len(all_vars) == 0 and not self.built:
-                raise ValueError(
-                    f"Layer '{self.name}' was never built "
-                    "and thus it doesn't have any variables. "
-                    f"However the weights file lists {len(store.keys())} "
-                    "variables for this layer.\n"
-                    "In most cases, this error indicates that either:\n\n"
-                    "1. The layer is owned by a parent layer that "
-                    "implements a `build()` method, but calling the "
-                    "parent's `build()` method did NOT create the state of "
-                    f"the child layer '{self.name}'. A `build()` method "
-                    "must create ALL state for the layer, including "
-                    "the state of any children layers.\n\n"
-                    "2. You need to implement "
-                    "the `def build_from_config(self, config)` method "
-                    f"on layer '{self.name}', to specify how to rebuild "
-                    "it during loading. "
-                    "In this case, you might also want to implement the "
-                    "method that generates the build config at saving time, "
-                    "`def get_build_config(self)`. "
-                    "The method `build_from_config()` is meant "
-                    "to create the state "
-                    "of the layer (i.e. its variables) upon deserialization.",
-                )
-            raise ValueError(
-                f"Layer '{self.name}' expected {len(all_vars)} variables, "
-                "but received "
-                f"{len(store.keys())} variables during loading. "
-                f"Expected: {[v.name for v in all_vars]}"
-            )
-
     @property
-    def quantization_variable_spec(self):
-        """Returns a dict mapping quantization modes to variable names.
+    def variable_serialization_spec(self):
+        """Returns a dict mapping quantization modes to variable names in order.
 
         This spec is used by `save_own_variables` and `load_own_variables` to
-        determine which variables should be saved/loaded for each quantization
-        mode.
+        determine the correct ordering of variables during serialization for
+        each quantization mode. `None` means no quantization.
         """
         return {
-            None: [],
-            "int8": ["kernel_scale"],
-            "int4": ["kernel_scale"],
+            None: [
+                "kernel",
+                "bias",
+            ],
+            "int8": [
+                "kernel",
+                "bias",
+                "kernel_scale",
+            ],
+            "int4": [
+                "kernel",
+                "bias",
+                "kernel_scale",
+            ],
             "float8": [
+                "kernel",
+                "bias",
                 "inputs_scale",
                 "inputs_amax_history",
                 "kernel_scale",
@@ -399,6 +357,7 @@ def quantization_variable_spec(self):
                 "outputs_grad_amax_history",
             ],
             "gptq": [
+                "bias",
                 "quantized_kernel",
                 "kernel_scale",
                 "kernel_zero",

keras/src/layers/core/einsum_dense.py

Lines changed: 42 additions & 83 deletions
@@ -326,25 +326,25 @@ def save_own_variables(self, store):
         if not self.built:
             return
         mode = self.quantization_mode
-        if mode not in self.quantization_variable_spec:
+        if mode not in self.variable_serialization_spec:
             raise self._quantization_mode_error(mode)
 
         # Kernel plus optional merged LoRA-aware scale (returns (kernel, None)
         # for None/gptq)
         kernel_value, merged_kernel_scale = self._get_kernel_with_merged_lora()
-
-        # Save the variables using the name as the key.
-        if mode != "gptq":
-            store["kernel"] = kernel_value
-        if self.bias is not None:
-            store["bias"] = self.bias
-        for name in self.quantization_variable_spec[mode]:
-            if name == "kernel_scale" and mode in ("int4", "int8"):
+        idx = 0
+        for name in self.variable_serialization_spec[mode]:
+            if name == "kernel":
+                store[str(idx)] = kernel_value
+            elif name == "bias" and self.bias is None:
+                continue
+            elif name == "kernel_scale" and mode in ("int4", "int8"):
                 # For int4/int8, the merged LoRA scale (if any) comes from
                 # `_get_kernel_with_merged_lora()`
-                store[name] = merged_kernel_scale
+                store[str(idx)] = merged_kernel_scale
             else:
-                store[name] = getattr(self, name)
+                store[str(idx)] = getattr(self, name)
+            idx += 1
 
     def load_own_variables(self, store):
         if not self.lora_enabled:
@@ -353,39 +353,18 @@ def load_own_variables(self, store):
         if not self.built:
             return
         mode = self.quantization_mode
-        if mode not in self.quantization_variable_spec:
+        if mode not in self.variable_serialization_spec:
             raise self._quantization_mode_error(mode)
 
-        # Determine whether to use the legacy loading method.
-        if "0" in store:
-            return self._legacy_load_own_variables(store)
-
-        # Load the variables using the name as the key.
-        if mode != "gptq":
-            self._kernel.assign(store["kernel"])
-        if self.bias is not None:
-            self.bias.assign(store["bias"])
-        for name in self.quantization_variable_spec[mode]:
-            getattr(self, name).assign(store[name])
-        if self.lora_enabled:
-            self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
-            self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
-
-    def _legacy_load_own_variables(self, store):
-        # The keys of the `store` will be saved as determined because the
-        # default ordering will change after quantization
-        mode = self.quantization_mode
-        targets = []
-        if mode != "gptq":
-            targets.append(self._kernel)
-        if self.bias is not None:
-            targets.append(self.bias)
-        targets.extend(
-            getattr(self, name)
-            for name in self.quantization_variable_spec[mode]
-        )
-        for i, variable in enumerate(targets):
-            variable.assign(store[str(i)])
+        idx = 0
+        for name in self.variable_serialization_spec[mode]:
+            if name == "kernel":
+                self._kernel.assign(store[str(idx)])
+            elif name == "bias" and self.bias is None:
+                continue
+            else:
+                getattr(self, name).assign(store[str(idx)])
+            idx += 1
         if self.lora_enabled:
             self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
             self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
@@ -418,53 +397,32 @@ def get_config(self):
         config["gptq_unpacked_column_size"] = self.gptq_unpacked_column_size
         return {**base_config, **config}
 
-    def _check_load_own_variables(self, store):
-        all_vars = self._trainable_variables + self._non_trainable_variables
-        if len(store.keys()) != len(all_vars):
-            if len(all_vars) == 0 and not self.built:
-                raise ValueError(
-                    f"Layer '{self.name}' was never built "
-                    "and thus it doesn't have any variables. "
-                    f"However the weights file lists {len(store.keys())} "
-                    "variables for this layer.\n"
-                    "In most cases, this error indicates that either:\n\n"
-                    "1. The layer is owned by a parent layer that "
-                    "implements a `build()` method, but calling the "
-                    "parent's `build()` method did NOT create the state of "
-                    f"the child layer '{self.name}'. A `build()` method "
-                    "must create ALL state for the layer, including "
-                    "the state of any children layers.\n\n"
-                    "2. You need to implement "
-                    "the `def build_from_config(self, config)` method "
-                    f"on layer '{self.name}', to specify how to rebuild "
-                    "it during loading. "
-                    "In this case, you might also want to implement the "
-                    "method that generates the build config at saving time, "
-                    "`def get_build_config(self)`. "
-                    "The method `build_from_config()` is meant "
-                    "to create the state "
-                    "of the layer (i.e. its variables) upon deserialization.",
-                )
-            raise ValueError(
-                f"Layer '{self.name}' expected {len(all_vars)} variables, "
-                "but received "
-                f"{len(store.keys())} variables during loading. "
-                f"Expected: {[v.name for v in all_vars]}"
-            )
-
     @property
-    def quantization_variable_spec(self):
-        """Returns a dict mapping quantization modes to variable names.
+    def variable_serialization_spec(self):
+        """Returns a dict mapping quantization modes to variable names in order.
 
         This spec is used by `save_own_variables` and `load_own_variables` to
-        determine which variables should be saved/loaded for each quantization
-        mode.
+        determine the correct ordering of variables during serialization for
+        each quantization mode. `None` means no quantization.
        """
         return {
-            None: [],
-            "int8": ["kernel_scale"],
-            "int4": ["kernel_scale"],
+            None: [
+                "kernel",
+                "bias",
+            ],
+            "int8": [
+                "kernel",
+                "bias",
+                "kernel_scale",
+            ],
+            "int4": [
+                "kernel",
+                "bias",
+                "kernel_scale",
+            ],
             "float8": [
+                "kernel",
+                "bias",
                 "inputs_scale",
                 "inputs_amax_history",
                 "kernel_scale",
@@ -473,6 +431,7 @@ def quantization_variable_spec(self):
                 "outputs_grad_amax_history",
             ],
             "gptq": [
+                "bias",
                 "quantized_kernel",
                 "kernel_scale",
                 "kernel_zero",
