@@ -29,8 +29,10 @@
     unwrap_model,
 )
 from paddlenlp.transformers.utils import dtype_byte_size
+from paddlenlp.utils import infohub
 from paddlenlp.utils.env import (
     LORA_WEIGHTS_NAME,
+    MAX_QUANTIZATION_TIMES,
     PADDLE_MASTER_WEIGHTS_NAME,
     PADDLE_OPTIMIZER_NAME,
     PADDLE_WEIGHTS_NAME,
@@ -239,9 +241,16 @@ def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, outp
         optimizer_name = _add_variant(SAFE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
         master_weights_name = _add_variant(SAFE_MASTER_WEIGHTS_NAME, self.args.optimizer_name_suffix)
 
+        sharded_optim_index = {}
         # save opt index json if checkpoint quantization is on.
-        if self.args.ckpt_quant_stage != "O0":
-            sharded_optim_index = {"ckpt_quant_stage": self.args.ckpt_quant_stage}
+        if self.args.ckpt_quant_stage != "O0" and "quant_reach_limit" not in infohub:
+            sharded_optim_index["ckpt_quant_stage"] = self.args.ckpt_quant_stage
+
+        sharded_optim_index["quant_ckpt_resume_times"] = (
+            infohub["quant_ckpt_resume_times"] if "quant_ckpt_resume_times" in infohub else 0
+        )
+
+        if len(sharded_optim_index) > 0:
             optimizer_index_name = SAFE_OPTIMIZER_INDEX_NAME
             path = os.path.join(output_dir, optimizer_index_name)
             if self.args.should_save:
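With this change the index dict is always populated (the resume counter is recorded even when quantization is off), so the write of `SAFE_OPTIMIZER_INDEX_NAME` is now guarded by `len(sharded_optim_index) > 0` rather than by the quantization stage alone. As an illustration only (values assumed, not taken from the PR), a first save with `ckpt_quant_stage="O1"` and nothing yet recorded in `infohub` would build an index like:

```python
# Illustrative sketch of the dict built above on a fresh run with ckpt_quant_stage="O1".
sharded_optim_index = {
    "ckpt_quant_stage": "O1",      # omitted once "quant_reach_limit" is set in infohub
    "quant_ckpt_resume_times": 0,  # read from infohub, defaults to 0 on the first save
}
```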
@@ -257,7 +266,7 @@ def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, outp
             signal_path=signal_dir,
             is_sync=is_sync_save,
             state_dict_type="optimizer_weight",
-            ckpt_quant_stage=self.args.ckpt_quant_stage,
+            ckpt_quant_stage=self.args.ckpt_quant_stage if "quant_reach_limit" not in infohub else "O0",
         )
         if master_weights is not None:
             self.async_handler._file_save_async_or_sync(
@@ -277,7 +286,7 @@ def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint, ckp
         optimizer_path = os.path.join(resume_from_checkpoint, optimizer_name)
         master_weights_path = os.path.join(resume_from_checkpoint, master_weights_name)
         # no quantization & no master weight represent O1 AMP strategy.
-        is_amp_o1 = True if not os.path.isfile(master_weights_path) and ckpt_quant_stage == "O0" else False
+        is_amp_o1 = self.args.fp16_opt_level == "O1"
 
         model_state_dict = get_expected_state_dict(model)
         struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()}  # get optimizer param mappings
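The removed heuristic inferred the AMP level from the absence of a master-weights file plus `ckpt_quant_stage == "O0"`, which no longer holds once checkpoints can be quantized; the loader now consults `self.args.fp16_opt_level` directly. A hedged sketch of the distinction this flag encodes (consistent with the comment kept above: under O1 no separate master weights are stored):

```python
# Sketch only, not the PaddleNLP API: what is_amp_o1 decides for this loader.
# Under fp16_opt_level "O1" there is no fp32 master-weights file to restore;
# under "O2" the fp32 master weights are saved separately and must be loaded.
def expects_master_weights(fp16_opt_level: str) -> bool:
    return fp16_opt_level != "O1"
```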
@@ -379,7 +388,7 @@ def save_unified_optimizer(self, model, optimizer, output_dir, signal_dir):
             signal_path=signal_dir,
             is_sync=is_sync_save,
             state_dict_type="optimizer_weight",
-            ckpt_quant_stage=self.args.ckpt_quant_stage,
+            ckpt_quant_stage=self.args.ckpt_quant_stage if "quant_reach_limit" not in infohub else "O0",
         )
         if master_weight_state_dict is not None:
             self.async_handler._file_save_async_or_sync(
@@ -429,10 +438,24 @@ def load_unified_optimizer(self, model, optimizer, resume_from_checkpoint):
         with open(os.path.join(resume_from_checkpoint, SAFE_OPTIMIZER_INDEX_NAME), "r") as f:
             index = json.loads(f.read())
 
+        # get quant ckpt info `ckpt_quant_stage` and `quant_ckpt_resume_times`
         ckpt_quant_stage = "O0"
         if "ckpt_quant_stage" in index:
             ckpt_quant_stage = index["ckpt_quant_stage"]
 
+        quant_ckpt_resume_times = 0
+        if "quant_ckpt_resume_times" in index:
+            quant_ckpt_resume_times = index["quant_ckpt_resume_times"]
+        # increment and save resume times in infohub
+        if ckpt_quant_stage != "O0":
+            quant_ckpt_resume_times += 1
+        infohub["quant_ckpt_resume_times"] = quant_ckpt_resume_times
+
+        # Quantization times exceeds the limit. Turn off the quantization strategy.
+        if quant_ckpt_resume_times >= MAX_QUANTIZATION_TIMES:
+            infohub["quant_reach_limit"] = True
+            logger.info("Checkpoint quantization time reach limit and will be closed.")
+
         # If not having merge optimizer, then load non-merge optimizer.
         if "weight_map" not in index:
             if self.args.data_parallel_rank == 0 or self.args.use_expert_parallel:
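Taken together, the additions above maintain a resume counter for quantized checkpoints: the counter is read from the index, incremented only when the checkpoint being loaded is actually quantized, published through `infohub`, and compared against `MAX_QUANTIZATION_TIMES` to decide whether to disable quantization for the rest of the run. A standalone sketch of that control flow, with a plain dict standing in for `infohub` and an assumed value for the limit (the real constant comes from `paddlenlp.utils.env`):

```python
# Minimal sketch of the resume-counting logic; a plain dict stands in for infohub
# and MAX_QUANTIZATION_TIMES is an assumed value, not the library constant.
MAX_QUANTIZATION_TIMES = 5
infohub = {}


def register_resume(index: dict) -> None:
    """Mirror of the load_unified_optimizer additions above."""
    ckpt_quant_stage = index.get("ckpt_quant_stage", "O0")
    quant_ckpt_resume_times = index.get("quant_ckpt_resume_times", 0)
    if ckpt_quant_stage != "O0":  # only a quantized checkpoint counts as a quantized resume
        quant_ckpt_resume_times += 1
    infohub["quant_ckpt_resume_times"] = quant_ckpt_resume_times
    if quant_ckpt_resume_times >= MAX_QUANTIZATION_TIMES:
        infohub["quant_reach_limit"] = True  # save paths then fall back to "O0"
```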
@@ -647,8 +670,12 @@ def unified_optimizer_into_shards(
     )
     sharded_optim_index = get_sharded_index(index_optimizer_filelist, total_optim_size_list)
 
-    if args.should_save and args.ckpt_quant_stage in ["O1", "O2"]:
-        sharded_optim_index["ckpt_quant_stage"] = args.ckpt_quant_stage
+    if args.should_save:
+        if args.ckpt_quant_stage in ["O1", "O2"] and "quant_reach_limit" not in infohub:
+            sharded_optim_index["ckpt_quant_stage"] = args.ckpt_quant_stage
+        sharded_optim_index["quant_ckpt_resume_times"] = (
+            infohub["quant_ckpt_resume_times"] if "quant_ckpt_resume_times" in infohub else 0
+        )
 
     if master_weights is not None:
         index_master_weight_filelist, total_master_weight_size_list = gather_sharded_object(
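Net effect across runs: the counter round-trips between the optimizer index on disk and `infohub` in memory, and once it reaches the limit every later checkpoint is written unquantized (`ckpt_quant_stage` is dropped from the index and the save calls pass "O0"). A rough end-to-end simulation under the same assumptions as the sketch above (illustrative names only, `MAX_QUANTIZATION_TIMES` assumed to be 5):

```python
# Rough simulation of repeated save/resume cycles under the logic in this diff.
def simulate(cycles: int, configured_stage: str = "O1", max_times: int = 5) -> None:
    index = {}  # stands in for the optimizer index JSON carried between runs
    for i in range(cycles):
        infohub = {}  # each resume starts a fresh process, so infohub starts empty
        # ---- load side (mirrors load_unified_optimizer) ----
        times = index.get("quant_ckpt_resume_times", 0)
        if index.get("ckpt_quant_stage", "O0") != "O0":
            times += 1
        infohub["quant_ckpt_resume_times"] = times
        if times >= max_times:
            infohub["quant_reach_limit"] = True
        # ---- save side (mirrors unified_optimizer_into_shards) ----
        index = {"quant_ckpt_resume_times": infohub["quant_ckpt_resume_times"]}
        if configured_stage in ["O1", "O2"] and "quant_reach_limit" not in infohub:
            index["ckpt_quant_stage"] = configured_stage
        print(i, index)


simulate(8)
# The index keeps "ckpt_quant_stage" for the first few cycles, then drops it once the
# counter reaches the limit, i.e. later checkpoints are stored without quantization.
```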