@@ -326,25 +326,25 @@ def save_own_variables(self, store):
         if not self.built:
             return
         mode = self.quantization_mode
-        if mode not in self.quantization_variable_spec:
+        if mode not in self.variable_serialization_spec:
             raise self._quantization_mode_error(mode)

         # Kernel plus optional merged LoRA-aware scale (returns (kernel, None)
         # for None/gptq)
         kernel_value, merged_kernel_scale = self._get_kernel_with_merged_lora()
-
-        # Save the variables using the name as the key.
-        if mode != "gptq":
-            store["kernel"] = kernel_value
-        if self.bias is not None:
-            store["bias"] = self.bias
-        for name in self.quantization_variable_spec[mode]:
-            if name == "kernel_scale" and mode in ("int4", "int8"):
+        idx = 0
+        for name in self.variable_serialization_spec[mode]:
+            if name == "kernel":
+                store[str(idx)] = kernel_value
+            elif name == "bias" and self.bias is None:
+                continue
+            elif name == "kernel_scale" and mode in ("int4", "int8"):
                 # For int4/int8, the merged LoRA scale (if any) comes from
                 # `_get_kernel_with_merged_lora()`
-                store[name] = merged_kernel_scale
+                store[str(idx)] = merged_kernel_scale
             else:
-                store[name] = getattr(self, name)
+                store[str(idx)] = getattr(self, name)
+            idx += 1

     def load_own_variables(self, store):
         if not self.lora_enabled:
@@ -353,39 +353,18 @@ def load_own_variables(self, store):
         if not self.built:
             return
         mode = self.quantization_mode
-        if mode not in self.quantization_variable_spec:
+        if mode not in self.variable_serialization_spec:
             raise self._quantization_mode_error(mode)

-        # Determine whether to use the legacy loading method.
-        if "0" in store:
-            return self._legacy_load_own_variables(store)
-
-        # Load the variables using the name as the key.
-        if mode != "gptq":
-            self._kernel.assign(store["kernel"])
-        if self.bias is not None:
-            self.bias.assign(store["bias"])
-        for name in self.quantization_variable_spec[mode]:
-            getattr(self, name).assign(store[name])
-        if self.lora_enabled:
-            self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
-            self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
-
-    def _legacy_load_own_variables(self, store):
-        # The keys of the `store` will be saved as determined because the
-        # default ordering will change after quantization
-        mode = self.quantization_mode
-        targets = []
-        if mode != "gptq":
-            targets.append(self._kernel)
-        if self.bias is not None:
-            targets.append(self.bias)
-        targets.extend(
-            getattr(self, name)
-            for name in self.quantization_variable_spec[mode]
-        )
-        for i, variable in enumerate(targets):
-            variable.assign(store[str(i)])
+        idx = 0
+        for name in self.variable_serialization_spec[mode]:
+            if name == "kernel":
+                self._kernel.assign(store[str(idx)])
+            elif name == "bias" and self.bias is None:
+                continue
+            else:
+                getattr(self, name).assign(store[str(idx)])
+            idx += 1

         if self.lora_enabled:
             self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
             self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))
@@ -418,53 +397,32 @@ def get_config(self):
             config["gptq_unpacked_column_size"] = self.gptq_unpacked_column_size
         return {**base_config, **config}

-    def _check_load_own_variables(self, store):
-        all_vars = self._trainable_variables + self._non_trainable_variables
-        if len(store.keys()) != len(all_vars):
-            if len(all_vars) == 0 and not self.built:
-                raise ValueError(
-                    f"Layer '{self.name}' was never built "
-                    "and thus it doesn't have any variables. "
-                    f"However the weights file lists {len(store.keys())} "
-                    "variables for this layer.\n"
-                    "In most cases, this error indicates that either:\n\n"
-                    "1. The layer is owned by a parent layer that "
-                    "implements a `build()` method, but calling the "
-                    "parent's `build()` method did NOT create the state of "
-                    f"the child layer '{self.name}'. A `build()` method "
-                    "must create ALL state for the layer, including "
-                    "the state of any children layers.\n\n"
-                    "2. You need to implement "
-                    "the `def build_from_config(self, config)` method "
-                    f"on layer '{self.name}', to specify how to rebuild "
-                    "it during loading. "
-                    "In this case, you might also want to implement the "
-                    "method that generates the build config at saving time, "
-                    "`def get_build_config(self)`. "
-                    "The method `build_from_config()` is meant "
-                    "to create the state "
-                    "of the layer (i.e. its variables) upon deserialization.",
-                )
-            raise ValueError(
-                f"Layer '{self.name}' expected {len(all_vars)} variables, "
-                "but received "
-                f"{len(store.keys())} variables during loading. "
-                f"Expected: {[v.name for v in all_vars]}"
-            )
-
     @property
-    def quantization_variable_spec(self):
-        """Returns a dict mapping quantization modes to variable names.
+    def variable_serialization_spec(self):
+        """Returns a dict mapping quantization modes to variable names in order.

         This spec is used by `save_own_variables` and `load_own_variables` to
-        determine which variables should be saved/loaded for each quantization
-        mode.
+        determine the correct ordering of variables during serialization for
+        each quantization mode. `None` means no quantization.
         """
         return {
-            None: [],
-            "int8": ["kernel_scale"],
-            "int4": ["kernel_scale"],
+            None: [
+                "kernel",
+                "bias",
+            ],
+            "int8": [
+                "kernel",
+                "bias",
+                "kernel_scale",
+            ],
+            "int4": [
+                "kernel",
+                "bias",
+                "kernel_scale",
+            ],
             "float8": [
+                "kernel",
+                "bias",
                 "inputs_scale",
                 "inputs_amax_history",
                 "kernel_scale",
@@ -473,6 +431,7 @@ def quantization_variable_spec(self):
                 "outputs_grad_amax_history",
             ],
             "gptq": [
+                "bias",
                 "quantized_kernel",
                 "kernel_scale",
                 "kernel_zero",