Commit 2985f90

support quant ckpt limit strategy (#9494)
* support quant ckpt limit strategy
* bug fix
* bug fix
* fix bug
* add log, fix bug
Parent commit: 756bce7

File tree

5 files changed: +68 additions, -17 deletions

paddlenlp/quantization/unified_checkpoint_quantization.py

Lines changed: 21 additions & 9 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import numpy as np
 import paddle
 from paddle.distributed import fleet
 
@@ -33,7 +34,7 @@
 from paddlenlp.utils.log import logger
 
 
-def dequant_unified_optimizer(state_dict, ckpt_quant_stage, scale_dict):
+def dequant_unified_optimizer(state_dict, ckpt_quant_stage, scale_dict, use_pd=False):
     """
     dequantize unified optimizer state dict.
     Args:
@@ -44,6 +45,7 @@ def dequant_unified_optimizer(state_dict, ckpt_quant_stage, scale_dict):
         scale_dict (`int`):
            compression checkpoint scale dict.
     """
+    logger.info(f"Start unified checkpoint dequantization, stage {ckpt_quant_stage}.")
     tp_rank, tp_degree = -1, 1
     if paddle.distributed.get_world_size() > 1:
         hcg = fleet.get_hybrid_communicate_group()
@@ -68,7 +70,7 @@ def dequant_unified_optimizer(state_dict, ckpt_quant_stage, scale_dict):
                dequant=True,
                tp_rank=tp_rank,
                tp_degree=tp_degree,
-               use_pd=True,
+               use_pd=use_pd,
            )
            state_dict[quant_key] = weight
        elif is_moment2:
@@ -85,10 +87,13 @@ def dequant_unified_optimizer(state_dict, ckpt_quant_stage, scale_dict):
                dequant=True,
                tp_rank=tp_rank,
                tp_degree=tp_degree,
-               use_pd=True,
+               use_pd=use_pd,
            )
            # cal m2
-           weight = paddle.square(1.0 / weight - eps)
+           if use_pd:
+               weight = paddle.square(1.0 / weight - eps)
+           else:
+               weight = np.square(1.0 / weight - eps)
            state_dict[quant_key] = weight
    elif ckpt_quant_stage == "O2":
        # set eps
@@ -117,7 +122,7 @@ def dequant_unified_optimizer(state_dict, ckpt_quant_stage, scale_dict):
                quant=False,
                tp_rank=tp_rank,
                tp_degree=tp_degree,
-               use_pd=True,
+               use_pd=use_pd,
                symmetry=True,
            )
            ratio_weight = group_wise_quant_dequant(
@@ -128,14 +133,19 @@ def dequant_unified_optimizer(state_dict, ckpt_quant_stage, scale_dict):
                quant=False,
                tp_rank=tp_rank,
                tp_degree=tp_degree,
-               use_pd=True,
+               use_pd=use_pd,
            )
 
-           ratio_weight = paddle.square(1.0 / ratio_weight - eps)
+           if use_pd:
+               ratio_weight = paddle.square(1.0 / ratio_weight - eps)
+           else:
+               ratio_weight = np.square(1.0 / ratio_weight - eps)
            state_dict[quant_key] = ratio_weight
            m1_state_dict[quant_key[: -len(MOMENT2_KEYNAME)] + MOMENT1_KEYNAME] = m1_weight
        state_dict.update(m1_state_dict)
 
+    logger.info(f"Unified checkpoint dequantization done, stage {ckpt_quant_stage}.")
+
    return state_dict
 
 
@@ -152,14 +162,15 @@ def quant_unified_optimizer(state_dict, state_dict_type, ckpt_quant_stage, async
        async_save (`bool`):
            whether use async_save.
    """
+    logger.info(f"Start unified checkpoint quantization, stage {ckpt_quant_stage}.")
+
    quant = False
    if ckpt_quant_stage != "O0":
        quant = True
    del_key = []
    if quant and state_dict_type == "optimizer_weight":
        scales_dict = {}
-       opt_keys = state_dict.keys()
-       for k in opt_keys:
+       for k in state_dict.keys():
            momentum1 = k.endswith(MOMENT1_KEYNAME)
            momentum2 = k.endswith(MOMENT2_KEYNAME)
 
@@ -205,5 +216,6 @@ def quant_unified_optimizer(state_dict, state_dict_type, ckpt_quant_stage, async
        state_dict.pop(k, None)
 
    state_dict.update(scales_dict)
+    logger.info(f"Unified checkpoint quantization done, stage {ckpt_quant_stage}.")
 
    return state_dict
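
As the `# cal m2` step above suggests, the stored tensor approximates 1.0 / (sqrt(moment2) + eps), and the new `use_pd` flag only chooses whether that reconstruction runs on Paddle tensors or NumPy arrays. A minimal sketch of the branch; `reconstruct_moment2` and its sample inputs are illustrative, not PaddleNLP code:

import numpy as np
import paddle

def reconstruct_moment2(ratio, eps, use_pd=False):
    # The checkpoint stores roughly 1.0 / (sqrt(moment2) + eps); invert and
    # square to recover moment2, staying in Paddle or NumPy depending on use_pd.
    if use_pd:
        return paddle.square(1.0 / ratio - eps)
    return np.square(1.0 / ratio - eps)

# NumPy path (the new default when use_pd is not passed).
m2_np = reconstruct_moment2(np.array([0.5, 0.25], dtype="float32"), eps=1e-8)
# Paddle path (load_state_dict passes use_pd=True, see model_utils.py below).
m2_pd = reconstruct_moment2(paddle.to_tensor([0.5, 0.25]), eps=1e-8, use_pd=True)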

paddlenlp/trainer/unified_checkpoint/unified_checkpoint.py

Lines changed: 34 additions & 7 deletions
@@ -29,8 +29,10 @@
     unwrap_model,
 )
 from paddlenlp.transformers.utils import dtype_byte_size
+from paddlenlp.utils import infohub
 from paddlenlp.utils.env import (
     LORA_WEIGHTS_NAME,
+    MAX_QUANTIZATION_TIMES,
     PADDLE_MASTER_WEIGHTS_NAME,
     PADDLE_OPTIMIZER_NAME,
     PADDLE_WEIGHTS_NAME,
@@ -239,9 +241,16 @@ def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, outp
        optimizer_name = _add_variant(SAFE_OPTIMIZER_NAME, self.args.optimizer_name_suffix)
        master_weights_name = _add_variant(SAFE_MASTER_WEIGHTS_NAME, self.args.optimizer_name_suffix)
 
+        sharded_optim_index = {}
        # save opt index json if checkpoint quantization is on.
-       if self.args.ckpt_quant_stage != "O0":
-           sharded_optim_index = {"ckpt_quant_stage": self.args.ckpt_quant_stage}
+        if self.args.ckpt_quant_stage != "O0" and "quant_reach_limit" not in infohub:
+            sharded_optim_index["ckpt_quant_stage"] = self.args.ckpt_quant_stage
+
+        sharded_optim_index["quant_ckpt_resume_times"] = (
+            infohub["quant_ckpt_resume_times"] if "quant_ckpt_resume_times" in infohub else 0
+        )
+
+        if len(sharded_optim_index) > 0:
            optimizer_index_name = SAFE_OPTIMIZER_INDEX_NAME
            path = os.path.join(output_dir, optimizer_index_name)
            if self.args.should_save:
@@ -257,7 +266,7 @@ def save_non_merge_optimizer(self, model, optim_state_dict, master_weights, outp
            signal_path=signal_dir,
            is_sync=is_sync_save,
            state_dict_type="optimizer_weight",
-           ckpt_quant_stage=self.args.ckpt_quant_stage,
+            ckpt_quant_stage=self.args.ckpt_quant_stage if "quant_reach_limit" not in infohub else "O0",
        )
        if master_weights is not None:
            self.async_handler._file_save_async_or_sync(
@@ -277,7 +286,7 @@ def load_non_merge_optimizer(self, model, optimizer, resume_from_checkpoint, ckp
        optimizer_path = os.path.join(resume_from_checkpoint, optimizer_name)
        master_weights_path = os.path.join(resume_from_checkpoint, master_weights_name)
        # no quantization & no master weight represent O1 AMP strategy.
-       is_amp_o1 = True if not os.path.isfile(master_weights_path) and ckpt_quant_stage == "O0" else False
+        is_amp_o1 = self.args.fp16_opt_level == "O1"
 
        model_state_dict = get_expected_state_dict(model)
        struct2static_name_mappings = {k: v.name for k, v in model_state_dict.items()}  # get optimizer param mappings
@@ -379,7 +388,7 @@ def save_unified_optimizer(self, model, optimizer, output_dir, signal_dir):
            signal_path=signal_dir,
            is_sync=is_sync_save,
            state_dict_type="optimizer_weight",
-           ckpt_quant_stage=self.args.ckpt_quant_stage,
+            ckpt_quant_stage=self.args.ckpt_quant_stage if "quant_reach_limit" not in infohub else "O0",
        )
        if master_weight_state_dict is not None:
            self.async_handler._file_save_async_or_sync(
@@ -429,10 +438,24 @@ def load_unified_optimizer(self, model, optimizer, resume_from_checkpoint):
        with open(os.path.join(resume_from_checkpoint, SAFE_OPTIMIZER_INDEX_NAME), "r") as f:
            index = json.loads(f.read())
 
+        # get quant ckpt info `ckpt_quant_stage` and `quant_ckpt_resume_times`
        ckpt_quant_stage = "O0"
        if "ckpt_quant_stage" in index:
            ckpt_quant_stage = index["ckpt_quant_stage"]
 
+        quant_ckpt_resume_times = 0
+        if "quant_ckpt_resume_times" in index:
+            quant_ckpt_resume_times = index["quant_ckpt_resume_times"]
+        # increment and save resume times in infohub
+        if ckpt_quant_stage != "O0":
+            quant_ckpt_resume_times += 1
+        infohub["quant_ckpt_resume_times"] = quant_ckpt_resume_times
+
+        # Quantization times exceeds the limit. Turn off the quantization strategy.
+        if quant_ckpt_resume_times >= MAX_QUANTIZATION_TIMES:
+            infohub["quant_reach_limit"] = True
+            logger.info("Checkpoint quantization time reach limit and will be closed.")
+
        # If not having merge optimizer, then load non-merge optimizer.
        if "weight_map" not in index:
            if self.args.data_parallel_rank == 0 or self.args.use_expert_parallel:
@@ -647,8 +670,12 @@ def unified_optimizer_into_shards(
            )
            sharded_optim_index = get_sharded_index(index_optimizer_filelist, total_optim_size_list)
 
-           if args.should_save and args.ckpt_quant_stage in ["O1", "O2"]:
-               sharded_optim_index["ckpt_quant_stage"] = args.ckpt_quant_stage
+            if args.should_save:
+                if args.ckpt_quant_stage in ["O1", "O2"] and "quant_reach_limit" not in infohub:
+                    sharded_optim_index["ckpt_quant_stage"] = args.ckpt_quant_stage
+                sharded_optim_index["quant_ckpt_resume_times"] = (
+                    infohub["quant_ckpt_resume_times"] if "quant_ckpt_resume_times" in infohub else 0
+                )
 
    if master_weights is not None:
        index_master_weight_filelist, total_master_weight_size_list = gather_sharded_object(
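
Taken together, these changes implement the checkpoint-quantization limit: each resume from a quantized checkpoint bumps a counter, and once it reaches MAX_QUANTIZATION_TIMES the optimizer state is saved without quantization (stage "O0"). A self-contained sketch of that bookkeeping, using a plain dict in place of `paddlenlp.utils.infohub`; the `update_quant_limit` helper is hypothetical, not part of the PR:

MAX_QUANTIZATION_TIMES = 1  # mirrors the constant added in paddlenlp/utils/env.py

def update_quant_limit(index, infohub):
    # Mimics the load path: read the resume counter from the optimizer index,
    # bump it if the checkpoint was quantized, and flip the limit flag.
    ckpt_quant_stage = index.get("ckpt_quant_stage", "O0")
    resume_times = index.get("quant_ckpt_resume_times", 0)
    if ckpt_quant_stage != "O0":
        resume_times += 1
    infohub["quant_ckpt_resume_times"] = resume_times
    if resume_times >= MAX_QUANTIZATION_TIMES:
        infohub["quant_reach_limit"] = True

infohub = {}
update_quant_limit({"ckpt_quant_stage": "O1", "quant_ckpt_resume_times": 0}, infohub)
# Mirrors the save path: once the limit flag is set, fall back to stage "O0".
stage_for_save = "O1" if "quant_reach_limit" not in infohub else "O0"
assert stage_for_save == "O0"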

paddlenlp/transformers/model_utils.py

Lines changed: 1 addition & 1 deletion
@@ -473,7 +473,7 @@ def load_state_dict(
        if len(scale_dict) != 0:
            if ckpt_quant_stage == "O0":
                raise ValueError('optimizer weight has quantization scales but `ckpt_quant_stage` is set to "O0"')
-           state_dict = dequant_unified_optimizer(state_dict, ckpt_quant_stage, scale_dict)
+            state_dict = dequant_unified_optimizer(state_dict, ckpt_quant_stage, scale_dict, use_pd=True)
 
    return state_dict

paddlenlp/utils/env.py

Lines changed: 1 addition & 0 deletions
@@ -123,6 +123,7 @@ def _get_bool_env(env_key: str, default_value: str) -> bool:
 SYMMETRY_QUANT_SCALE = "@scales"
 ASYMMETRY_QUANT_SCALE_MIN = "@min_scales"
 ASYMMETRY_QUANT_SCALE_MAX = "@max_scales"
+MAX_QUANTIZATION_TIMES = 1
 
 # LLM Inference related environment variables
 # Note(@Wanglongzhi2001): MAX_BSZ, SPECULATE_MAX_BSZ, MAX_DRAFT_TOKENS must be the same as definition in get_output / save_output
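
With MAX_QUANTIZATION_TIMES set to 1, the diff implies that optimizer state stays quantized only until the first resume from a quantized checkpoint; after that, the limit flag makes later saves fall back to stage "O0". For orientation only, a hypothetical optimizer index written while quantization is still active might look like this (keys come from the diff, values are made up):

import json

# Hypothetical SAFE_OPTIMIZER_INDEX_NAME payload (values are illustrative).
example_index = {
    "ckpt_quant_stage": "O1",       # quantization still active at save time
    "quant_ckpt_resume_times": 0,   # no resume from a quantized ckpt yet
}
print(json.dumps(example_index, indent=2))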

tests/llm/test_finetune.py

Lines changed: 11 additions & 0 deletions
@@ -13,11 +13,14 @@
 # limitations under the License.
 from __future__ import annotations
 
+import json
+import os
 import sys
 import unittest
 
 from parameterized import parameterized_class
 
+from paddlenlp.utils.env import SAFE_OPTIMIZER_INDEX_NAME
 from tests.parallel_launch import TestMultipleGpus
 from tests.testing_utils import argv_context_guard, load_test_config
 
@@ -92,8 +95,16 @@ def test_ckpt_quant(self):
        finetune_config["output_dir"] = self.output_dir
 
        self.runfirst(finetune_config)
+
+        # get `quant_ckpt_resume_times`
+        with open(os.path.join(self.output_dir, "checkpoint-1", SAFE_OPTIMIZER_INDEX_NAME), "r") as r:
+            index = json.loads(r.read())
+        quant_ckpt_resume_times = index["quant_ckpt_resume_times"]
+
        self.rerun(finetune_config)
 
+        self.assertEqual(quant_ckpt_resume_times, 0)
+
    def runfirst(self, train_args):
        self.run_n1c2(self.run_sft, **train_args)
