
Commit 98bfddd

[llm]fix bug in ChatGLM merge tp params and prefix model generation (#6730)
* fix
* fix prefix generation
* fix chatglm tp
Parent commit: 8c01eeb

7 files changed: +93 additions, -78 deletions


llm/README.md (3 additions, 9 deletions)

````diff
@@ -167,18 +167,12 @@ python -u -m paddle.distributed.launch --gpus "0,1" finetune_generation.py ./
 When training with tensor parallelism (TP), intermediate checkpoints are often stored as multiple TP parameter shards to save the time of merging TP parameters during training. The provided shard-merging script can be used to merge them into a single checkpoint.
 
 ```
-python merge_tp_params.py \
-    --model_name_or_path ./checkpoints/chatglm_v2_sft_ckpts/checkpoint-7163 \
-    --merge_model_path ./checkpoints/chatglm_v2_sft_ckpts/checkpoint_merge \
-    --dtype "float16" \
-    --with_tokenizer
+python merge_tp_params.py \
+    --model_name_or_path ./checkpoints/chatglm_v2_sft_ckpts/checkpoint-100
 ```
 
 **Arguments:**
-- `model_name_or_path`: Required. Pretrained model name or local model path, used to warm-start the model and tokenizer. Defaults to None.
-- `merge_model_path`: Required. Path where the merged parameters are saved. Defaults to None.
-- `dtype`: Required. Model parameter dtype. Defaults to None.
-- `with_tokenizer`: Whether to also save the tokenizer. Defaults to False.
+- `model_name_or_path`: Required. Local path to the TP model parameters. Defaults to None.
 - `device`: Device to run on. Defaults to gpu.
 
 ### 3.7 Merging LoRA parameters
````
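As a quick sanity check after running the script: the new `main()` writes the merged weights in place, so a `model_state.pdparams` file should appear next to the per-rank `model_state.tp00.pdparams`, `model_state.tp01.pdparams`, ... shards and be loadable without any distributed setup. A minimal sketch, assuming the checkpoint path from the example above:

```python
import os

import paddle

# Path from the README example above; adjust to your own checkpoint.
ckpt = "./checkpoints/chatglm_v2_sft_ckpts/checkpoint-100"

# merge_tp_params.py saves the merged weights next to the TP shards.
merged_path = os.path.join(ckpt, "model_state.pdparams")
assert os.path.exists(merged_path), "run merge_tp_params.py first"

# Inspect the merged state dict without instantiating the model.
state_dict = paddle.load(merged_path, return_numpy=True)
for name, tensor in list(state_dict.items())[:5]:
    print(name, tensor.shape, tensor.dtype)
```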

llm/merge_tp_params.py (69 additions, 31 deletions)

```diff
@@ -11,52 +11,90 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import importlib
+import os
+
 import paddle
-from paddle.distributed import fleet
 
-from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer
+from paddlenlp.transformers import AutoConfig
+from paddlenlp.transformers.auto.modeling import MAPPING_NAMES
+from paddlenlp.utils.log import logger
 
 
 def parse_arguments():
     import argparse
 
     parser = argparse.ArgumentParser()
     parser.add_argument("--model_name_or_path", default=None, required=True, help="The directory of model.")
-    parser.add_argument("--merge_model_path", default=None, required=True, help="The directory of merged model.")
     parser.add_argument("--device", type=str, default="gpu", help="Device")
-    parser.add_argument("--dtype", type=str, default=None, required=True, help="Model dtype")
-    parser.add_argument("--with_tokenizer", type=bool, default=True, help="Save tokenizer at the same time")
     return parser.parse_args()
 
 
-def merge():
+def load_tp_params(tp_degree, path):
+    tp_state_dict_list = []
+    for tp in range(tp_degree):
+        tp_state_dict = {}
+        tmp = paddle.load(os.path.join(path, f"model_state.tp{tp:0>2d}.pdparams"), return_numpy=True)
+        for k, v in tmp.items():
+            tp_state_dict[k] = v
+        tp_state_dict_list.append(tp_state_dict)
+
+    return tp_state_dict_list
+
+
+def merge_tensor_parallel(model_class, state_dict_list, config) -> dict:
+    """Merge per-rank tensor-parallel state dicts into a single state dict.
+
+    Args:
+        model_class: the model class providing the tensor-parallel mappings
+        state_dict_list (list[dict]): one state dict per tensor-parallel rank
+        config (PretrainedConfig): the PretrainedConfig instance of the model
+    """
+    name_action_mappings = model_class._get_tensor_parallel_mappings(config, is_split=False)
+    state_keys_map = model_class._resolve_prefix_keys(name_action_mappings.keys(), state_dict_list[0].keys())
+
+    for k, v in state_keys_map.items():
+        name_action_mappings[v] = name_action_mappings.pop(k)
+
+    state_dict_to_save = {}
+    for key in state_dict_list[0].keys():
+        tensor = state_dict_list[0][key]
+        if key in name_action_mappings:
+            ret = [x[key] for x in state_dict_list]
+            action = name_action_mappings.pop(key)
+            tensor = action(ret)
+
+        state_dict_to_save[key] = tensor
+
+    if len(name_action_mappings) > 0:
+        for x in name_action_mappings.keys():
+            logger.warning(f"key <{x}> is expected to be merged for tensor parallelism but was not found in the model state.")
+
+    logger.info("The merged state dict contains the following tensors:")
+    for k, v in state_dict_to_save.items():
+        logger.info(f"{k}, {v.shape}, {v.dtype}")
+
+    return state_dict_to_save
+
+
+def main():
     args = parse_arguments()
     paddle.set_device(args.device)
-    tensor_parallel_degree = paddle.distributed.get_world_size()
-    tensor_parallel_rank = 0
-    if tensor_parallel_degree > 1:
-        strategy = fleet.DistributedStrategy()
-        strategy.hybrid_configs = {
-            "dp_degree": 1,
-            "mp_degree": tensor_parallel_degree,
-            "pp_degree": 1,
-            "sharding_degree": 1,
-        }
-        fleet.init(is_collective=True, strategy=strategy)
-        hcg = fleet.get_hybrid_communicate_group()
-        tensor_parallel_rank = hcg.get_model_parallel_rank()
-
-    model = AutoModelForCausalLM.from_pretrained(
-        args.model_name_or_path,
-        dtype=args.dtype,
-        tensor_parallel_degree=tensor_parallel_degree,
-        tensor_parallel_rank=tensor_parallel_rank,
-    )
-    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
-    if tensor_parallel_rank == 0:
-        model.save_pretrained(args.merge_model_path, merge_tensor_parallel=tensor_parallel_degree > 1)
-        tokenizer.save_pretrained(args.merge_model_path)
+    config = AutoConfig.from_pretrained(args.model_name_or_path)
+    init_class = config["architectures"][0]
+    import_class = importlib.import_module(f"paddlenlp.transformers.{MAPPING_NAMES[init_class[:-11]]}.modeling")
+    model_class = getattr(import_class, init_class)
+
+    if config.tensor_parallel_degree > 1:
+        tp_state_dict_list = load_tp_params(config.tensor_parallel_degree, args.model_name_or_path)
+        state_dict_to_save = merge_tensor_parallel(
+            model_class=model_class, state_dict_list=tp_state_dict_list, config=config
+        )
+
+        logger.info("Saving")
+        paddle.save(state_dict_to_save, os.path.join(args.model_name_or_path, "model_state.pdparams"))
+    else:
+        logger.info("No need to merge since config.tensor_parallel_degree <= 1.")
 
 
 if __name__ == "__main__":
-    merge()
+    main()
```
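For readers unfamiliar with how the merge works: the per-key actions returned by `_get_tensor_parallel_mappings(config, is_split=False)` each take the list of per-rank shards of one parameter and concatenate them along whichever axis was split at training time, roughly the output axis for column-parallel layers and the input axis for row-parallel ones. A minimal, self-contained sketch of that idea with NumPy stand-ins; the `is_column_parallel` flag is hypothetical and this is not the PaddleNLP implementation:

```python
import numpy as np

def make_merge_action(is_column_parallel: bool):
    """Return a function that merges the per-rank shards of one parameter."""
    # Column-parallel layers split the output dim (last axis);
    # row-parallel layers split the input dim (first axis).
    axis = -1 if is_column_parallel else 0
    return lambda shards: np.concatenate(shards, axis=axis)

# Two ranks holding the halves of a [4, 8] column-parallel weight.
rank0 = np.zeros((4, 4), dtype="float16")
rank1 = np.ones((4, 4), dtype="float16")

merge = make_merge_action(is_column_parallel=True)
print(merge([rank0, rank1]).shape)  # (4, 8)
```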

llm/opt/pt_argument.json (0 additions, 30 deletions): this file was deleted.

llm/predict_generation.py (0 additions, 1 deletion)

```diff
@@ -116,7 +116,6 @@ def __init__(self, args):
             model=self.model,
             prefix_path=self.args.prefix_path,
             postprocess_past_key_value=prefix_tuning_params["postprocess_past_key_value"],
-            pad_attention_mask=prefix_tuning_params["pad_attention_mask"],
         )
         self.model.eval()
         self.tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, padding_side="left")
```

llm/quant.py (1 addition, 1 deletion)

```diff
@@ -167,7 +167,7 @@ def apply_gptq(quant_args, trainer, ptq_dataloader):
     for cur_name, cur_layer in model.named_sublayers():
         if type(cur_layer) in [paddle.nn.Linear, ColumnParallelLinear, RowParallelLinear]:
             num_layer += 1
-            logger.info("GPTQ layer", num_layer, cur_name)
+            logger.info(f"GPTQ layer: {num_layer}, {cur_name}")
             parent_layer, sub_name = find_parent_layer_and_sub_name(model, cur_name)
             cur_quant_layer = GPTQ(cur_layer)
             setattr(parent_layer, sub_name, cur_quant_layer)
```
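For context, the loop around this log line follows a common layer-swap pattern: walk `named_sublayers()`, and for every matching linear layer replace it on its parent via `setattr`. A rough sketch of that pattern with a dummy wrapper (illustrative only; `GPTQ` and `find_parent_layer_and_sub_name` are the real helpers in `llm/quant.py` and are not reproduced here):

```python
import paddle

class DummyWrapper(paddle.nn.Layer):
    """Stand-in for a quantization wrapper such as GPTQ."""

    def __init__(self, layer):
        super().__init__()
        self.layer = layer

    def forward(self, x):
        return self.layer(x)

class TinyModel(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.proj = paddle.nn.Linear(8, 8)
        self.act = paddle.nn.ReLU()

    def forward(self, x):
        return self.act(self.proj(x))

model = TinyModel()

# Collect the targets first, then swap, so the sublayer iterator is not
# invalidated while the model is being mutated.
targets = [name for name, layer in model.named_sublayers() if isinstance(layer, paddle.nn.Linear)]
for name in targets:
    # The real code resolves dotted names via find_parent_layer_and_sub_name();
    # for this flat model the parent is simply the model itself.
    setattr(model, name, DummyWrapper(getattr(model, name)))

out = model(paddle.randn([2, 8]))
print(out.shape)  # [2, 8]
```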

paddlenlp/peft/prefix/prefix_model.py (19 additions, 5 deletions)

```diff
@@ -121,15 +121,30 @@ def generate(self, **kwargs):
     def _prepare_inputs_for_generation(self, *args, **kwargs):
         model_kwargs = self.model_prepare_inputs_for_generation(*args, **kwargs)
         attention_mask = model_kwargs["attention_mask"]
+        batch_size = model_kwargs["input_ids"].shape[0]
         if self.pad_attention_mask is not None:
             attention_mask = self.pad_attention_mask(
                 model_kwargs["input_ids"].shape, self.prefix_config.num_prefix_tokens, attention_mask
             )
         else:
-            prefix_attention_mask = paddle.ones(
-                [model_kwargs["input_ids"].shape[0], self.prefix_config.num_prefix_tokens], dtype=attention_mask.dtype
-            )
-            attention_mask = paddle.concat((prefix_attention_mask, attention_mask), axis=1)
+            if len(attention_mask.shape) == 2:
+                prefix_attention_mask = paddle.ones(
+                    [batch_size, self.prefix_config.num_prefix_tokens], dtype=attention_mask.dtype
+                )
+            elif len(attention_mask.shape) == 3:
+                batch_size, src_seq_len, tgt_seq_len = attention_mask.shape
+                prefix_attention_mask = paddle.ones(
+                    [batch_size, src_seq_len, self.prefix_config.num_prefix_tokens], dtype=attention_mask.dtype
+                )
+            elif len(attention_mask.shape) == 4:
+                batch_size, num_heads, src_seq_len, tgt_seq_len = attention_mask.shape
+                prefix_attention_mask = paddle.ones(
+                    [batch_size, num_heads, src_seq_len, self.prefix_config.num_prefix_tokens],
+                    dtype=attention_mask.dtype,
+                )
+            else:
+                raise ValueError(f"Unexpected attention_mask shape: {attention_mask.shape}")
+            attention_mask = paddle.concat((prefix_attention_mask, attention_mask), axis=-1)
         model_kwargs["attention_mask"] = attention_mask
 
         if "past_key_values" in self.forward_keys:
@@ -139,7 +154,6 @@ def _prepare_inputs_for_generation(self, *args, **kwargs):
         else:
             raise NotImplementedError("Model does not support past_key_values either cache")
         if model_kwargs[key] is None:
-            batch_size = model_kwargs["input_ids"].shape[0]
             past_key_values = self._get_past_key_values(batch_size)
             model_kwargs[key] = past_key_values
         return model_kwargs
```
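The new branches all do the same thing: prepend an all-ones block of `num_prefix_tokens` entries along the last (key) axis, whatever rank the mask has (2D `[batch, seq]`, 3D `[batch, src, tgt]`, or 4D `[batch, heads, src, tgt]`). A compact shape check of that behaviour, independent of the PEFT classes and using a hypothetical prefix length of 4:

```python
import paddle

NUM_PREFIX_TOKENS = 4  # hypothetical prefix length

def extend_mask(attention_mask):
    """Prepend an all-ones prefix block along the last (key) axis."""
    prefix_shape = list(attention_mask.shape)
    prefix_shape[-1] = NUM_PREFIX_TOKENS
    prefix_mask = paddle.ones(prefix_shape, dtype=attention_mask.dtype)
    return paddle.concat((prefix_mask, attention_mask), axis=-1)

# 2D, 3D, and 4D masks all grow from 7 to 11 along the last axis.
for shape in ([2, 7], [2, 7, 7], [2, 8, 7, 7]):
    mask = paddle.ones(shape, dtype="int64")
    print(extend_mask(mask).shape)
```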

paddlenlp/transformers/chatglm/modeling.py (1 addition, 1 deletion)

```diff
@@ -751,7 +751,7 @@ def forward(self, hidden_states):
 
 
 class ChatGLMForCausalLM(ChatGLMPretrainedModel):
-    _keys_to_ignore_on_save = [r"lm_head.weight"]
+    _keys_to_ignore_on_save = [r"lm_head.decoder_weight"]
     _tied_weights_keys = ["lm_head.weight"]
 
     def __init__(self, config: ChatGLMConfig):
```
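The corrected key matters because ChatGLM stores the tied lm_head weight under `decoder_weight`; entries in `_keys_to_ignore_on_save` are matched against state-dict keys so that tied tensors are not written twice. A rough sketch of that filtering idea, not the PaddleNLP save path itself:

```python
import re

_keys_to_ignore_on_save = [r"lm_head.decoder_weight"]

state_dict = {
    "transformer.word_embeddings.weight": "...",
    "lm_head.decoder_weight": "...",  # tied to the embedding, safe to skip
}

# Drop any key matching an ignore pattern before writing the checkpoint.
filtered = {
    k: v
    for k, v in state_dict.items()
    if not any(re.search(pattern, k) for pattern in _keys_to_ignore_on_save)
}
print(sorted(filtered))  # only the embedding weight remains
```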
