
Commit 99c5910

Implement huggingface checkpoint loading

1 parent a1b067e commit 99c5910

1 file changed: +194 -14 lines

examples/pre-training/ernie/pretrain.py (194 additions & 14 deletions)
@@ -35,6 +35,10 @@
     PdArgumentParser,
     get_last_checkpoint,
 )
+from paddleformers.trainer.unified_checkpoint import unified_checkpoint
+from paddleformers.transformers.model_utils import unwrap_model
+
+from safetensors import safe_open
 
 try:
     from paddleformers.utils.downloader import get_static_model_on_pdc
@@ -202,6 +206,190 @@ def _collate_data(data, stack_fn=Stack()):
     return train_dataset, valid_dataset, test_dataset, _collate_data
 
 
+def load_huggingface_checkpoint(model, args):
+    fused_rms_norm_replace = [
+        ("self_attn.fused_rms_norm_linear.rms_norm_weight", "input_layernorm.weight"),
+        ("self_attn.fused_rms_norm_linear.linear_weight", "self_attn.qkv_proj.weight"),
+    ]
+    shared_layers_prefix = "shared_layers.embed_weight_share."
+    unnamed_layers = ["ernie.norm.weight", "lm_head.weight"]
+
+    logger.info(f"Loading huggingface checkpoint from {args.model_name_or_path}")
+    with open(
+        os.path.join(args.model_name_or_path, "model.safetensors.index.json")
+    ) as f:
+        weight_map = json.load(f)["weight_map"]
+
+    ep_degree = fleet.get_hybrid_communicate_group().get_expert_parallel_world_size()
+    ep_rank = fleet.get_hybrid_communicate_group().get_expert_parallel_rank()
+    expert_offset = (model.config.moe_num_experts // ep_degree) * ep_rank
+    use_torch_format = False
+
+    def param_to_weight(name):
+        # for PP=1, we only need to substitute the fused_rms_norm and expert_id
+        for src, dst in fused_rms_norm_replace:
+            name = name.replace(src, dst)
+        if m := re.search(r"mlp\.experts\.(\d+)", name):
+            expert_id = expert_offset + int(m.group(1))
+            s, e = m.span()
+            name = name[:s] + f"mlp.experts.{expert_id}" + name[e:]
+        if isinstance(model, ErnieMoEForCausalLM):
+            return name
+
+        # for PP>1, we also need to handle special layers and adjust layer_idx
+        if name.startswith(shared_layers_prefix):
+            return "ernie." + name[len(shared_layers_prefix) :]
+        layer_idx, stem = name.split(".", maxsplit=1)
+        if stem == "weight":
+            return unnamed_layers.pop(0)
+        if stem.startswith("mtp"):
+            return f"ernie.{stem}"
+        return f"ernie.layers.{int(layer_idx) - 1}.{stem}"
+
+    def try_torch_format(weight_key):
+        if weight_key.startswith("ernie."):
+            weight_key = "model." + weight_key[6:]
+
+        key_decompose = [weight_key]
+        if ".up_gate_proj." in weight_key:
+            key_decompose = [
+                weight_key.replace(".up_gate_proj.", ".gate_proj."),
+                weight_key.replace(".up_gate_proj.", ".up_proj."),
+            ]
+        elif ".qkv_proj." in weight_key:
+            key_decompose = [
+                weight_key.replace(".qkv_proj.", ".q_proj."),
+                weight_key.replace(".qkv_proj.", ".k_proj."),
+                weight_key.replace(".qkv_proj.", ".v_proj."),
+            ]
+
+        tensor_decompose = []
+        for key in key_decompose:
+            if not (weight_file := weight_map.get(key)):
+                return None
+            with safe_open(
+                os.path.join(args.model_name_or_path, weight_file),
+                framework="numpy",
+            ) as f:
+                tensor = paddle.to_tensor(f.get_tensor(key))
+                if "_proj." in key or ".gate." in key:
+                    tensor = tensor.T.contiguous()
+                tensor_decompose.append(tensor)
+
+        if len(tensor_decompose) == 1:
+            return tensor_decompose[0]
+        else:
+            return paddle.concat(tensor_decompose, axis=-1)
+
+    def auto_fix_shape(param, weight):
+        assert len(param.shape) == len(weight.shape), "rank not match"
+        assert all(
+            p_dim <= w_dim for p_dim, w_dim in zip(param.shape, weight.shape)
+        ), "weight too small"
+        indices = tuple(slice(0, dim) for dim in param.shape)
+        return weight[indices].contiguous()
+
+    for name, param in model.named_parameters():
+        weight_key = param_to_weight(name)
+        if weight_file := weight_map.get(weight_key):
+            with safe_open(
+                os.path.join(args.model_name_or_path, weight_file),
+                framework="numpy",
+            ) as f:
+                weight = paddle.to_tensor(f.get_tensor(weight_key))
+        elif (weight := try_torch_format(weight_key)) is not None:
+            use_torch_format = True
+        else:
+            logger.warning(
+                f"param `{name}`'s weight `{weight_key}` not found. "
+                "Skip initializing."
+            )
+            continue
+        if use_torch_format and "lm_head" in weight_key:
+            weight = weight.T.contiguous()
+        if param.shape != weight.shape:
+            logger.warning(
+                f"param `{name}`'s shape doesn't match weight `{weight_key}`: "
+                f"{param.shape} and {weight.shape}. Auto fixing."
+            )
+            weight = auto_fix_shape(param, weight)
+        param.copy_(weight)
+
+
+def get_expected_state_dict(model, **kwargs):
+    fused_rms_norm_replace = [
+        ("self_attn.fused_rms_norm_linear.rms_norm_weight", "input_layernorm.weight"),
+        ("self_attn.fused_rms_norm_linear.linear_weight", "self_attn.qkv_proj.weight"),
+    ]
+    shared_layers_prefix = "shared_layers.embed_weight_share."
+    unnamed_layers = ["ernie.norm.weight", "lm_head.weight"]
+
+    model = unwrap_model(model)
+    hcg = fleet.get_hybrid_communicate_group()
+    ep_degree = hcg.get_expert_parallel_world_size()
+    ep_rank = hcg.get_expert_parallel_rank()
+    expert_offset = (model.config.moe_num_experts // ep_degree) * ep_rank
+
+    if model.config.head_dim is None:
+        head_dim = model.config.hidden_size // model.config.num_attention_heads
+    else:
+        head_dim = model.config.head_dim
+    q_dim = head_dim * model.config.num_attention_heads
+    kv_dim = head_dim * model.config.num_key_value_heads
+
+    def copy_attr(out, param):
+        if hasattr(param, "is_distributed"):
+            out.is_distributed = param.is_distributed
+        if hasattr(param, "no_sync"):
+            out.no_sync = param.no_sync
+        return out
+
+    def param_to_weight(name):
+        # for PP=1, we only need to substitute the fused_rms_norm and expert_id
+        for src, dst in fused_rms_norm_replace:
+            name = name.replace(src, dst)
+        if m := re.search(r"\.experts\.(\d+)\.", name):
+            expert_id = expert_offset + int(m.group(1))
+            s, e = m.span()
+            name = name[:s] + f".experts.{expert_id}." + name[e:]
+        if isinstance(model, ErnieMoEForCausalLM):
+            return name
+
+        # for PP>1, we also need to handle special layers and adjust layer_idx
+        if name.startswith(shared_layers_prefix):
+            return "ernie." + name[len(shared_layers_prefix) :]
+        layer_idx, stem = name.split(".", maxsplit=1)
+        if stem == "weight":
+            return unnamed_layers.pop(0)
+        if stem.startswith("mtp"):
+            return f"ernie.{stem}"
+        return f"ernie.layers.{int(layer_idx) - 1}.{stem}"
+
+    state_dict = {}
+    for name, param in model.state_dict().items():
+        name = param_to_weight(name)
+        if name.startswith("ernie."):
+            name = "model." + name[6:]
+
+        if "_proj." in name or ".gate." in name or "lm_head" in name:
+            param = copy_attr(param.T, param)
+
+        if ".up_gate_proj." in name:
+            gate, up = param.split(2)
+            gate, up = copy_attr(gate, param), copy_attr(up, param)
+            state_dict[name.replace(".up_gate_proj.", ".gate_proj.")] = gate
+            state_dict[name.replace(".up_gate_proj.", ".up_proj.")] = up
+        elif ".qkv_proj." in name:
+            assert q_dim + kv_dim * 2 == param.shape[0]
+            state_dict[name.replace(".qkv_proj.", ".q_proj.")] = param[:q_dim]
+            state_dict[name.replace(".qkv_proj.", ".k_proj.")] = param[q_dim:-kv_dim]
+            state_dict[name.replace(".qkv_proj.", ".v_proj.")] = param[-kv_dim:]
+        else:
+            state_dict[name] = param
+
+    return state_dict
+
+
 def main():
     if set_affinity is not None:
         set_affinity_code = set_affinity()
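
Note: the per-parameter lookup in load_huggingface_checkpoint follows the standard Hugging Face sharded-safetensors layout, where model.safetensors.index.json maps each weight name to the shard file that stores it. A minimal standalone sketch of that lookup (the checkpoint directory and weight key are placeholders, not values from this commit):

import json
import os

import paddle
from safetensors import safe_open

ckpt_dir = "/path/to/hf_checkpoint"            # placeholder directory
key = "model.layers.0.input_layernorm.weight"  # placeholder weight name

# The index file maps every weight name to the shard that contains it.
with open(os.path.join(ckpt_dir, "model.safetensors.index.json")) as f:
    weight_map = json.load(f)["weight_map"]

# Open only the shard holding the requested tensor and convert it to a Paddle tensor.
with safe_open(os.path.join(ckpt_dir, weight_map[key]), framework="numpy") as f:
    weight = paddle.to_tensor(f.get_tensor(key))

print(weight.shape)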
@@ -520,21 +708,12 @@ def sname_to_tname(pp_model):
         cfg.enable_delay_scale_loss = args.enable_delay_scale_loss
         register_pp_reshard_information(cfg.num_hidden_layers)
 
-        if args.from_scratch:
-            model = ErnieMoEForCausalLMPipe(cfg)
-        else:
-            model = ErnieMoEForCausalLMPipe.from_pretrained(
-                args.model_name_or_path,
-                config=cfg,
-            )
+        model = ErnieMoEForCausalLMPipe(cfg)
     else:
-        if args.from_scratch:
-            model = ErnieMoEForCausalLM(cfg)
-        else:
-            model = ErnieMoEForCausalLM.from_pretrained(
-                args.model_name_or_path,
-                config=cfg,
-            )
+        model = ErnieMoEForCausalLM(cfg)
+
+    if not args.from_scratch:
+        load_huggingface_checkpoint(model, args)
 
     cfg = model.config
     logger.info(f"using model type:{type(model)}")
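
Note: with the change above, the model is always constructed from the config, and load_huggingface_checkpoint fills the weights whenever training is not from scratch. One detail in that loader worth calling out is the expert-parallel offset: each rank owns only a slice of the experts, so rank-local expert indices are shifted into the global numbering before the checkpoint lookup. An illustrative sketch of the arithmetic (the values are assumptions, not taken from the commit):

# Assumed example values, for illustration only.
moe_num_experts = 64   # total experts in the checkpoint
ep_degree = 4          # expert-parallel world size
ep_rank = 1            # this rank's position in the expert-parallel group

experts_per_rank = moe_num_experts // ep_degree  # 16
expert_offset = experts_per_rank * ep_rank       # 16

# A rank-local parameter name such as "...mlp.experts.3..." is therefore
# looked up as "...mlp.experts.19..." in the Hugging Face checkpoint.
local_expert_id = 3
global_expert_id = expert_offset + local_expert_id
assert global_expert_id == 19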
@@ -581,6 +760,7 @@ def sname_to_tname(pp_model):
     if args.do_train:
         train_result = trainer.train(resume_from_checkpoint=checkpoint)
         metrics = train_result.metrics
+        unified_checkpoint.get_expected_state_dict = get_expected_state_dict
         trainer.save_model(args.output_dir)
         trainer.log_metrics("train", metrics)
         trainer.save_metrics("train", metrics)
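
Note: the new assignment above routes unified-checkpoint saving through get_expected_state_dict, which renames fused Ernie parameters back to Hugging Face keys and splits qkv_proj into q_proj/k_proj/v_proj (and up_gate_proj into gate_proj/up_proj). A sketch of the split dimensions it relies on (the config values are illustrative assumptions):

import paddle

# Assumed example config values, for illustration only.
hidden_size = 1024
num_attention_heads = 16
num_key_value_heads = 4
head_dim = hidden_size // num_attention_heads  # 64

q_dim = head_dim * num_attention_heads   # 1024
kv_dim = head_dim * num_key_value_heads  # 256

# After transposing to the Hugging Face layout, the fused projection has
# q_dim + 2 * kv_dim rows; slicing along axis 0 recovers q, k and v.
qkv = paddle.randn([q_dim + 2 * kv_dim, hidden_size])
q, k, v = qkv[:q_dim], qkv[q_dim:-kv_dim], qkv[-kv_dim:]
assert q.shape[0] == q_dim and k.shape[0] == kv_dim and v.shape[0] == kv_dim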
