@@ -235,94 +235,65 @@ def save_own_variables(self, store):
         # Do nothing if the layer isn't yet built
         if not self.built:
             return
-        # The keys of the `store` will be saved as determined because the
-        # default ordering will change after quantization
         mode = self.quantization_mode
-
-        # For int4/int8, the merged LoRA scale (if any) comes from
-        # `_get_kernel_with_merged_lora()` and is appended below.
-        MODE_SPEC = {
-            None: [],
-            "int4": [],
-            "int8": [],
-            "float8": [
-                "inputs_scale",
-                "inputs_amax_history",
-                "kernel_scale",
-                "kernel_amax_history",
-                "outputs_grad_scale",
-                "outputs_grad_amax_history",
-            ],
-            "gptq": [
-                "quantized_kernel",
-                "kernel_scale",
-                "kernel_zero",
-                "g_idx",
-            ],
-        }
-
-        if mode not in MODE_SPEC:
+        if mode not in self.quantization_variable_spec:
             raise self._quantization_mode_error(mode)

         # Kernel plus optional merged LoRA-aware scale (returns (kernel, None)
         # for None/gptq)
         kernel_value, merged_kernel_scale = self._get_kernel_with_merged_lora()

-        targets = []
+        # Save the variables using the name as the key.
         if mode != "gptq":
-            targets.append(kernel_value)
+            store["kernel"] = kernel_value
         if self.bias is not None:
-            targets.append(self.bias)
-        if merged_kernel_scale is not None and mode in ("int4", "int8"):
-            targets.append(merged_kernel_scale)
-
-        # Append per-mode attributes (order matters)
-        targets.extend(getattr(self, name) for name in MODE_SPEC[mode])
-
-        for i, var in enumerate(targets):
-            store[str(i)] = var
+            store["bias"] = self.bias
+        for name in self.quantization_variable_spec[mode]:
+            if name == "kernel_scale" and mode in ("int4", "int8"):
+                # For int4/int8, the merged LoRA scale (if any) comes from
+                # `_get_kernel_with_merged_lora()`
+                store[name] = merged_kernel_scale
+            else:
+                store[name] = getattr(self, name)

     def load_own_variables(self, store):
         if not self.lora_enabled:
             self._check_load_own_variables(store)
         # Do nothing if the layer isn't yet built
         if not self.built:
             return
-        # The keys of the `store` will be saved as determined because the
-        # default ordering will change after quantization
         mode = self.quantization_mode
+        if mode not in self.quantization_variable_spec:
+            raise self._quantization_mode_error(mode)

-        # Per-mode variable spec (order matters).
-        MODE_SPEC = {
-            None: [],
-            "int8": ["kernel_scale"],
-            "int4": ["kernel_scale"],
-            "float8": [
-                "inputs_scale",
-                "inputs_amax_history",
-                "kernel_scale",
-                "kernel_amax_history",
-                "outputs_grad_scale",
-                "outputs_grad_amax_history",
-            ],
-            "gptq": [
-                "quantized_kernel",
-                "kernel_scale",
-                "kernel_zero",
-                "g_idx",
-            ],
-        }
+        # Determine whether to use the legacy loading method.
+        if "0" in store:
+            return self._legacy_load_own_variables(store)

-        if mode not in MODE_SPEC:
-            raise self._quantization_mode_error(mode)
+        # Load the variables using the name as the key.
+        if mode != "gptq":
+            self._kernel.assign(store["kernel"])
+        if self.bias is not None:
+            self.bias.assign(store["bias"])
+        for name in self.quantization_variable_spec[mode]:
+            getattr(self, name).assign(store[name])
+        if self.lora_enabled:
+            self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
+            self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))

+    def _legacy_load_own_variables(self, store):
+        # The keys of the `store` will be saved as determined because the
+        # default ordering will change after quantization
+        mode = self.quantization_mode
         targets = []
         if mode != "gptq":
             targets.append(self._kernel)
         if self.bias is not None:
             targets.append(self.bias)
-        targets.extend(getattr(self, name) for name in MODE_SPEC[mode])
-
+        targets.extend(
+            getattr(self, name)
+            for name in self.quantization_variable_spec[mode]
+        )
         for i, variable in enumerate(targets):
             variable.assign(store[str(i)])
         if self.lora_enabled:
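
Note on backward compatibility: the old format keyed the store by position (`store[str(i)]`), while the new format keys it by variable name, so `load_own_variables` treats the presence of the key `"0"` as the signal to fall back to `_legacy_load_own_variables`. Below is a minimal sketch of the two layouts for an int8-quantized layer with a bias; the plain dicts and placeholder strings stand in for the real store and tensor values and are illustration only, not code from this change.

# Plain dicts standing in for the checkpoint store; values are placeholders.

# Legacy layout: positional keys, so correctness depends on save order.
legacy_store = {
    "0": "kernel value",        # quantized kernel
    "1": "bias value",          # bias, only present if the layer has one
    "2": "kernel_scale value",  # merged LoRA-aware scale for int4/int8
}

# New layout: name keys, independent of save order.
named_store = {
    "kernel": "kernel value",
    "bias": "bias value",
    "kernel_scale": "kernel_scale value",
}

def is_legacy(store):
    # Mirrors the check in `load_own_variables`: only a legacy,
    # index-keyed checkpoint contains the key "0".
    return "0" in store

assert is_legacy(legacy_store)
assert not is_legacy(named_store)
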
@@ -385,6 +356,34 @@ def _check_load_own_variables(self, store):
                 f"Expected: {[v.name for v in all_vars]}"
             )

+    @property
+    def quantization_variable_spec(self):
+        """Returns a dict mapping quantization modes to variable names.
+
+        This spec is used by `save_own_variables` and `load_own_variables` to
+        determine which variables should be saved/loaded for each quantization
+        mode.
+        """
+        return {
+            None: [],
+            "int8": ["kernel_scale"],
+            "int4": ["kernel_scale"],
+            "float8": [
+                "inputs_scale",
+                "inputs_amax_history",
+                "kernel_scale",
+                "kernel_amax_history",
+                "outputs_grad_scale",
+                "outputs_grad_amax_history",
+            ],
+            "gptq": [
+                "quantized_kernel",
+                "kernel_scale",
+                "kernel_zero",
+                "g_idx",
+            ],
+        }
+
     def quantized_build(self, kernel_shape, mode, config=None):
         if mode == "int8":
             self._int8_build(kernel_shape)
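
The diff above centralizes the per-mode variable lists in the new `quantization_variable_spec` property, which `save_own_variables`, `load_own_variables`, and the legacy fallback all consult. A hedged sketch of how the spec translates into the set of name keys written to the store for each mode follows; the `expected_keys` helper is illustrative only and not part of the layer.

# Illustrative helper (not part of the layer): derive the keys that
# `save_own_variables` writes for a given quantization mode.
def expected_keys(spec, mode, has_bias=True):
    keys = []
    if mode != "gptq":
        keys.append("kernel")  # gptq stores its kernel via "quantized_kernel"
    if has_bias:
        keys.append("bias")
    keys.extend(spec[mode])
    return keys

# The same mapping returned by `quantization_variable_spec`.
spec = {
    None: [],
    "int8": ["kernel_scale"],
    "int4": ["kernel_scale"],
    "float8": [
        "inputs_scale", "inputs_amax_history",
        "kernel_scale", "kernel_amax_history",
        "outputs_grad_scale", "outputs_grad_amax_history",
    ],
    "gptq": ["quantized_kernel", "kernel_scale", "kernel_zero", "g_idx"],
}

print(expected_keys(spec, "int8"))
# ['kernel', 'bias', 'kernel_scale']
print(expected_keys(spec, "gptq"))
# ['bias', 'quantized_kernel', 'kernel_scale', 'kernel_zero', 'g_idx']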