-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathconvert_ckpt_deepseek3.py
More file actions
926 lines (786 loc) · 49.9 KB
/
convert_ckpt_deepseek3.py
File metadata and controls
926 lines (786 loc) · 49.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (c) Huawei Technologies Co., Ltd. 2024-2025. All rights reserved.
import argparse
import json
import logging as logger
import os
from collections import defaultdict
import numpy as np
import safetensors
import torch
import safetensors.torch
import bitsandbytes as bnb
# NOTE: `logger` is the standard `logging` module imported under an alias, so the two
# calls below configure the root logger (empty format string, INFO level).
logger.basicConfig(format="")
logger.getLogger().setLevel(logger.INFO)
# Model dimensions hard-coded by this converter; CkptConvert.__init__ copies them onto the instance.
# Hidden size: used to reshape the attention down-projections and the grouped expert weights.
HIDDEN_SIZE = 7168
# Number of routed experts read per MoE layer (router weight/bias are sliced to this size).
NUM_EXPERTS = 256
# Upper bound accepted for `num_dense_layers`; compared against first_k_dense_replace.
FIRST_K_DENSE_REPLACE = 3
# HF layer index of the first MTP layer (further MTP layers use MTP_LAYER_INDEX + i).
MTP_LAYER_INDEX = 61
# Multi-head latent attention sizes used when splitting q_b_proj / kv_b_proj (mla_mm_split).
NUM_ATTENTION_HEADS = 128
QK_HEAD_DIM = 128
QK_POS_EMB_HEAD_DIM = 64
V_HEAD_DIM = 128
class CkptConvert(object):
"""
Converts a HuggingFace checkpoint to Megatron format.
Args:
hf_model_path (str): HuggingFace model path.
mg_save_path (str): Megatron model save path.
num_layers (int): Number of transformer layers.
tp_size (int, optional): Degree of tensor model parallelism. Defaults to 1.
pp_size (int, optional): Degree of pipeline model parallelism. Defaults to 1.
ep_size (int, optional): Degree of expert model parallelism. Defaults to 1.
vpp_stage (int, optional): The stage number in the virtual pipeline parallelism. Defaults to None.
num_dense_layers (int, optional): The number of first k dense layers. Defaults to 3.
num_layer_list (str, optional): Specifies the number of parallel pipeline layers. If None, all blocks have the same number of layers. Defaults to None.
noop_layers (str, optional): Comma-separated indices of no-op layers that should be skipped during conversion. Defaults to None.
moe_grouped_gemm (bool, optional): Whether to use grouped GEMM for MoE layers.
moe_tp_extend_ep (bool, optional): Whether to use the tp group to extend expert parallelism.
mla_mm_split (bool, optional): Whether to split up-proj in MLA.
dualpipe (bool, optional): Whether to use dualpipe.
mtp_num_layers (int, optional): The number of MTP layers. Defaults to 0.
qlora_nf4 (bool, optional): Whether to use QLORA NF4. Defaults to False.
"""
def __init__(
        self,
        hf_model_path: str,
        mg_save_path: str,
        num_layers: int,
        tp_size: int = 1,
        pp_size: int = 1,
        ep_size: int = 1,
        num_dense_layers: int = 3,
        num_layer_list: str = None,
        noop_layers: str = None,
        vpp_stage: int = None,
        moe_grouped_gemm: bool = False,
        moe_tp_extend_ep: bool = False,
        mla_mm_split: bool = False,
        dualpipe: bool = False,
        mtp_num_layers: int = 0,
        qlora_nf4: bool = False,
):
    """Store the conversion settings, validate them and build the layer maps.

    ``dualpipe`` enables the dualpipe layout when it is either ``True`` or the
    schedule name ``"dualpipev"``; any other value leaves it disabled.

    Raises:
        FileNotFoundError: if ``hf_model_path`` does not exist.
        ValueError: if dualpipe is combined with ``vpp_stage`` or if
            ``_valid_parameter`` rejects the configuration.
    """
    self.tp_size = tp_size
    self.pp_size = pp_size
    self.ep_size = ep_size
    self.num_layers = num_layers
    self.vpp_stage = vpp_stage
    if vpp_stage is not None:
        # Number of virtual chunks held by every pipeline rank.
        self.vpp_size = self.num_layers // self.pp_size // self.vpp_stage

    self.hf_model_path = hf_model_path
    self.mg_save_path = mg_save_path
    self.num_layer_list = num_layer_list
    self.noop_layers = noop_layers
    self.moe_grouped_gemm = moe_grouped_gemm
    self.moe_tp_extend_ep = moe_tp_extend_ep
    self.mla_mm_split = mla_mm_split
    # Accept both the documented boolean and the schedule name; every later
    # decision uses this single normalized flag so they cannot disagree.
    self.dualpipe = dualpipe is True or dualpipe == "dualpipev"
    self.first_k_dense_replace = num_dense_layers
    self.mtp_num_layers = mtp_num_layers

    if not os.path.exists(self.hf_model_path):
        raise FileNotFoundError(f"Model path does not exist: {self.hf_model_path}")

    if self.dualpipe:
        if vpp_stage:
            raise ValueError("dualpipe is not compatible with virtual pipeline parallel.")
        # Dualpipe always uses two virtual chunks per pipeline rank.
        self.vpp_size = 2
        self.vpp_stage = self.num_layers // self.pp_size // self.vpp_size

    self.hidden_size = HIDDEN_SIZE
    self.num_experts = NUM_EXPERTS
    self.num_attention_heads = NUM_ATTENTION_HEADS
    self.qk_head_dim = QK_HEAD_DIM
    self.qk_pos_emb_head_dim = QK_POS_EMB_HEAD_DIM
    self.v_head_dim = V_HEAD_DIM
    self.mtp_layer_number = MTP_LAYER_INDEX
    self.qlora_nf4 = qlora_nf4

    self._valid_parameter()

    # Without (virtual) pipeline chunks a flat pp_rank -> layers map is enough;
    # otherwise the map is keyed by pp_rank and then vpp_rank.
    if self.vpp_stage is None:
        self.pprank_layer_idxs = defaultdict()
        self.get_pprank_hf_layeridxs()
    else:
        self.vpprank_layer_idxs = defaultdict(dict)
        self.get_vpprank_hf_layeridxs()
@staticmethod
def qlora_nf4_weight(weight):
    """Quantize ``weight`` as NF4 and return ``(quantized data, quant_state)``."""
    params_4bit = bnb.nn.Params4bit(weight, requires_grad=weight.requires_grad, quant_type="nf4")
    # Round trip through the 'npu' device and back to CPU before reading the result.
    # NOTE(review): presumably the device move is what triggers quantization - confirm.
    on_cpu = params_4bit.to('npu').cpu()
    return on_cpu.data, on_cpu.quant_state
def qlora_nf4_quant(self, mg_model, ep_rank, tp_rank, key, weight):
    """Quantize ``weight`` and store it plus its quant-state entries in ``mg_model``.

    The quantized data replaces ``mg_model[ep_rank][tp_rank][key]``; every item of
    the packed quant state is stored next to it under ``"<key>.<state name>"``.
    """
    quantized, state = self.qlora_nf4_weight(weight)
    rank_weights = mg_model[ep_rank][tp_rank]
    rank_weights[key] = quantized
    packed_state = state.as_dict(packed=True)
    for state_name in packed_state:
        rank_weights[f"{key}.{state_name}"] = packed_state[state_name].detach()
@staticmethod
def load_hf_model(file_path):
    """Log the path being read and return the tensors stored in that safetensors file."""
    message = f"Loading the checkpoint from {file_path}."
    logger.info(message)
    tensors = safetensors.torch.load_file(file_path)
    return tensors
@staticmethod
def mg_path_process(mg_path):
"""megatron model path"""
iter_mg_path = os.path.join(mg_path, "iter_0000001")
if not os.path.exists(mg_path):
os.makedirs(mg_path, exist_ok=True)
with open(os.path.join(mg_path, "latest_checkpointed_iteration.txt"), 'w') as f:
f.write("1")
return iter_mg_path
def generate_mg_weights_dir(self, tp_rank, pp_rank, ep_rank):
    """Return the ``mp_rank_*`` directory name for one (tp, pp, ep) rank triple.

    The tp rank is always present (two digits). The pp rank and the ep rank are
    appended (three digits each, in that order) only when the corresponding
    parallel size is not 1.
    """
    name_parts = [f"mp_rank_{tp_rank:02}"]
    if self.pp_size != 1:
        name_parts.append(f"{pp_rank:03}")
    if self.ep_size != 1:
        name_parts.append(f"{ep_rank:03}")
    return "_".join(name_parts)
def _valid_parameter(self):
    """Validate the parallel/layer configuration stored on the instance.

    Raises:
        ValueError: when more dense layers are requested than supported, when
            dualpipe is used with tensor parallelism but without
            ``moe_tp_extend_ep``, when the layers cannot be split evenly over the
            pipeline (and virtual pipeline) stages, or when ``num_layer_list`` is
            inconsistent with ``pp_size``/``num_layers`` or combined with
            ``vpp_stage``/``noop_layers``.
    """
    if self.first_k_dense_replace > FIRST_K_DENSE_REPLACE:
        # The limit itself is a valid value, so the message must not say "less than".
        raise ValueError(f"first_k_dense_replace should not be greater than {FIRST_K_DENSE_REPLACE}")

    if self.dualpipe and self.tp_size > 1 and not self.moe_tp_extend_ep:
        raise ValueError("When dualpipe is enabled, moe-tp-extend-ep should be used at the same time.")

    if self.num_layer_list is None:
        if self.num_layers % self.pp_size != 0:
            raise ValueError("number of layers should be divisible by the pipeline parallel size")
        if self.vpp_stage is not None:
            # Layers per pipeline stage (integer division, not the remainder, which
            # is always 0 here) must split evenly into chunks of vpp_stage layers.
            if (self.num_layers // self.pp_size) % self.vpp_stage != 0:
                raise ValueError("number of pp_stage should be divisible by the vpp_stage")
        return

    # An explicit per-stage layer list was given.
    layer_list = list(map(int, self.num_layer_list.split(',')))
    if self.vpp_stage is not None:
        raise ValueError("num_layer_list and vpp cannot be configured at the same time")
    if len(layer_list) != self.pp_size:
        raise ValueError("number of layer_list should be equal to pipeline parallel size")
    if sum(layer_list) != self.num_layers:
        raise ValueError("sum of layer_list should be equal to num_layers")
    if self.noop_layers is not None:
        raise ValueError("num_layer_list and noop_layers cannot be configured at the same time")
    if self.num_layers != 61:
        raise ValueError("num_layer_list supports only full parameters")
def get_layer_files_map(self):
    """Map every layer (or non-layer weight name) to the safetensors files holding it.

    Reads ``model.safetensors.index.json`` from ``self.hf_model_path``. Keys that
    start with ``model.layers.`` are grouped under their integer layer index; every
    other weight name is used as the key unchanged.

    Returns:
        defaultdict(set): layer index (int) or weight name (str) -> set of file names.
    """
    index_path = os.path.join(self.hf_model_path, "model.safetensors.index.json")
    # JSON is UTF-8; do not depend on the platform's default encoding.
    with open(index_path, encoding="utf-8") as index_file:
        weight_to_file = json.load(index_file)["weight_map"]

    layer_prefix = "model.layers."
    layer_map_dict = defaultdict(set)
    for weight_name, file_name in weight_to_file.items():
        if weight_name.startswith(layer_prefix):
            # "model.layers.<idx>.<rest>" -> <idx>
            map_key = int(weight_name[len(layer_prefix):].split('.')[0])
        else:
            map_key = weight_name
        layer_map_dict[map_key].add(file_name)
    return layer_map_dict
def get_pprank_hf_layeridxs(self) -> None:
    """Fill ``self.pprank_layer_idxs`` with the HF layer indices of every pp rank.

    No-op layers reduce the number of real layers a stage receives. When fewer
    than ``FIRST_K_DENSE_REPLACE`` dense layers are requested, the MoE layers are
    numbered from 3 upward. MTP layers, if any, are appended to the last stage.
    """
    noop_ids = [] if self.noop_layers is None else list(map(int, self.noop_layers.split(",")))
    real_layer_count = self.num_layers - len(noop_ids)

    # HF layer indices that still have to be handed out, in order.
    if self.first_k_dense_replace < FIRST_K_DENSE_REPLACE:
        moe_layer_count = real_layer_count - self.first_k_dense_replace
        pending_layers = list(range(self.first_k_dense_replace)) + [i + 3 for i in range(moe_layer_count)]
    else:
        pending_layers = list(range(real_layer_count))

    # Number of real layers owned by each pipeline stage.
    if self.num_layer_list is not None:
        layers_each_pp = list(map(int, self.num_layer_list.split(',')))
    else:
        layers_each_pp = [self.num_layers // self.pp_size] * self.pp_size
        for noop_id in noop_ids:
            owner_rank = noop_id // (self.num_layers // self.pp_size)
            layers_each_pp[owner_rank] -= 1

    for pp_rank in range(self.pp_size):
        stage_layers = []
        for _ in range(layers_each_pp[pp_rank]):
            stage_layers.append(pending_layers.pop(0))
        self.pprank_layer_idxs[pp_rank] = stage_layers

    # mtp layer
    if self.mtp_num_layers:
        mtp_layers = [self.mtp_layer_number + offset for offset in range(self.mtp_num_layers)]
        self.pprank_layer_idxs[self.pp_size - 1].extend(mtp_layers)
# Fills self.vpprank_layer_idxs[pp_rank][vpp_rank] with the list of HF layer indices
# owned by each (pipeline rank, virtual chunk) pair. Two layouts are supported:
# the interleaved VPP layout (not dualpipe) and the dualpipe layout.
def get_vpprank_hf_layeridxs(self) -> None:
"""vpp_rank -> hf layer map"""
# No-op layers are placeholders; only the remaining "real" layers map to HF layers.
num_noop_layers = 0 if self.noop_layers is None else len(list(map(int, self.noop_layers.split(","))))
num_real_layers = self.num_layers - num_noop_layers
num_layer_list_ = [i for i in range(num_real_layers)]
# With fewer dense layers than FIRST_K_DENSE_REPLACE, MoE layers are numbered from 3 upward.
if self.first_k_dense_replace < FIRST_K_DENSE_REPLACE:
num_real_layers = self.num_layers - num_noop_layers
num_moe_layers = num_real_layers - self.first_k_dense_replace
num_layer_list_ = [i for i in range(self.first_k_dense_replace)] + [i + 3 for i in range(num_moe_layers)]
if not self.dualpipe:
if self.vpp_stage is not None:
# Every (pp_rank, vpp_rank) chunk starts with vpp_stage layers; no-op layers are subtracted below.
layers_each_vpp = [[self.vpp_stage] * self.vpp_size for _ in range(self.pp_size)]
# examples: num_layers8,pp2,vpp_stage2 [[0 1, 4 5], [2 3, 6 7]]
# no noop layer --> layers_each_vpp:[[2,2], [2,2]]
# noop4,5 --> layers_each_vpp:[[2,0], [2,2]]
if self.noop_layers is not None:
for layer in list(map(int, self.noop_layers.split(","))):
vpp_idx = layer // self.vpp_stage // self.pp_size
pp_idx = layer % (self.pp_size * self.vpp_stage) // self.vpp_stage
layers_each_vpp[pp_idx][vpp_idx] -= 1
# Hand out the HF layers chunk by chunk: vpp rank is the outer loop, pp rank the inner one.
for vpp_rank in range(self.vpp_size):
for pp_rank in range(self.pp_size):
self.vpprank_layer_idxs[pp_rank][vpp_rank] = [num_layer_list_.pop(0) for _ in
range(layers_each_vpp[pp_rank][vpp_rank])]
else:
# Dualpipe layout.
# NOTE(review): this branch builds its own list from range(self.num_layers) and does not
# use num_layer_list_, so the first_k_dense_replace renumbering above is not applied here - confirm intended.
noop_layers_list = None if not self.noop_layers else np.array(
sorted(list(map(int, self.noop_layers.split(",")))))
min_noop_layer = None if not self.noop_layers else noop_layers_list[0]
dualpipe_layer_list = []
layers_each_pp = self.num_layers // self.pp_size
layer_pop_num = layers_each_pp // 2
all_layer_list = [i for i in range(self.num_layers)]
# dualpipe_layer_list example
# pp2: [0 1 2 3 4 5 6 7] -> [0 1 6 7 | 2 3 4 5]
# pp4: [0 1 2 3 4 5 6 7] -> [0 7 | 1 6 | 2 5 | 3 4]
# Repeatedly take layer_pop_num layers from the front and from the back of the remaining list.
while all_layer_list:
dualpipe_layer_list.extend(all_layer_list[:layer_pop_num])
dualpipe_layer_list.extend(all_layer_list[-layer_pop_num:])
all_layer_list = all_layer_list[layer_pop_num:-layer_pop_num]
# calc pp idx and vpp idx of each hf layer
pp_rank, vpp_rank = 0, 0
each_pp_layer = self.num_layers // self.pp_size
for idx, layer in enumerate(dualpipe_layer_list):
if vpp_rank not in self.vpprank_layer_idxs[pp_rank]:
self.vpprank_layer_idxs[pp_rank][vpp_rank] = []
if not self.noop_layers:
self.vpprank_layer_idxs[pp_rank][vpp_rank].append(layer)
else:
# ignore noop layer
# The rank counters still advance for a skipped no-op layer so later layers keep their position.
if layer in noop_layers_list:
if (idx + 1) % self.vpp_stage == 0:
vpp_rank += 1
if (idx + 1) % each_pp_layer == 0:
pp_rank += 1
vpp_rank = 0
continue
if layer < min_noop_layer:
self.vpprank_layer_idxs[pp_rank][vpp_rank].append(layer)
if layer > min_noop_layer:
# remove noop layer index
# Shift the index down by the number of no-op layers that precede this layer.
before_nums = sum(noop_layers_list < layer)
self.vpprank_layer_idxs[pp_rank][vpp_rank].append(layer - before_nums)
# update vpp_rank
if (idx + 1) % self.vpp_stage == 0:
vpp_rank += 1
# update pp_rank, reset vpp_rank
if (idx + 1) % each_pp_layer == 0:
pp_rank += 1
vpp_rank = 0
# MTP layers use HF indices starting at self.mtp_layer_number and go into the last virtual chunk.
if self.mtp_num_layers:
nextn_layer_list = [self.mtp_layer_number + i for i in range(self.mtp_num_layers)]
# for dualpipe, mtp layer in pp0vpp1
mtp_pp_rank = 0 if self.dualpipe else self.pp_size - 1
self.vpprank_layer_idxs[mtp_pp_rank][self.vpp_size - 1].extend(nextn_layer_list)
# Loads every safetensors file that contains a weight needed by the given pipeline rank
# (and virtual chunk, when vpp_rank is given) and returns one merged name -> tensor dict.
# Because whole files are loaded, the result may also contain weights of other layers.
def load_matched_hf_weights(self, pp_rank, vpp_rank=None):
"""Read the safetensors file corresponding to the layer of pp_rank."""
if vpp_rank is None:
layer_list = self.pprank_layer_idxs[pp_rank]
else:
# Copy so that extending the list below does not modify the stored vpp layer map.
layer_list = self.vpprank_layer_idxs[pp_rank][vpp_rank].copy()
if pp_rank == self.pp_size - 1 and self.mtp_num_layers:
nextn_layer_list = [self.mtp_layer_number + i for i in range(self.mtp_num_layers)]
layer_list.extend(nextn_layer_list)
# NOTE(review): the index JSON is re-read on every call of this method.
layer_files_map_dict = self.get_layer_files_map()
st_filename_list = []
for layer in layer_list:
# start with model.layers.[layer_number], contains the mtp layer.
st_filename_list.extend(list(layer_files_map_dict[layer]))
# The first pipeline rank needs the embedding; with dualpipe it also needs the LM head and final norm.
if pp_rank == 0:
st_filename_list.extend(list(layer_files_map_dict["model.embed_tokens.weight"]))
if self.dualpipe:
st_filename_list.extend(list(layer_files_map_dict["lm_head.weight"]))
st_filename_list.extend(list(layer_files_map_dict["model.norm.weight"]))
# Without dualpipe the final norm and LM head belong to the last pipeline rank.
if pp_rank == self.pp_size - 1 and not self.dualpipe:
st_filename_list.extend(list(layer_files_map_dict["model.norm.weight"]))
st_filename_list.extend(list(layer_files_map_dict["lm_head.weight"]))
# De-duplicate and sort so every file is loaded once, in a deterministic order.
st_filename_list = list(set(st_filename_list))
st_filename_list.sort()
all_pp_weights = {}
for filename in st_filename_list:
cur_weights = self.load_hf_model(os.path.join(self.hf_model_path, filename))
all_pp_weights.update(cur_weights)
return all_pp_weights
def set_model_preprocess(self, weights_dict, mg_model):
    """Move the word embedding into ``mg_model``, split over the tp ranks.

    ``model.embed_tokens.weight`` is removed from ``weights_dict``, chunked along
    dim 0 into ``tp_size`` pieces and a copy of piece ``tp_rank`` is stored for
    every ep rank under ``embedding.word_embeddings.weight``.
    """
    target_key = "embedding.word_embeddings.weight"
    embedding = weights_dict.pop("model.embed_tokens.weight")
    for ep_rank in range(self.ep_size):
        shards = torch.chunk(embedding, self.tp_size, dim=0)
        tp_rank = 0
        while tp_rank < self.tp_size:
            mg_model[ep_rank][tp_rank][target_key] = shards[tp_rank].clone()
            tp_rank += 1
def set_model_postprocess(self, weights_dict, mg_model):
    """Move the final norm and the LM head into ``mg_model``.

    The final norm is copied unchanged to every rank; its key is
    ``final_layernorm.weight`` when MTP layers are configured and
    ``decoder.final_layernorm.weight`` otherwise. The LM head is chunked along
    dim 0 over the tp ranks and, with ``qlora_nf4``, additionally quantized.
    """
    norm_weight = weights_dict.pop("model.norm.weight")
    head_weight = weights_dict.pop("lm_head.weight")
    head_key = "output_layer.weight"
    for ep_rank in range(self.ep_size):
        head_shards = torch.chunk(head_weight, self.tp_size, dim=0)
        for tp_rank in range(self.tp_size):
            norm_key = "final_layernorm.weight" if self.mtp_num_layers else "decoder.final_layernorm.weight"
            mg_model[ep_rank][tp_rank][norm_key] = norm_weight.clone()
            mg_model[ep_rank][tp_rank][head_key] = head_shards[tp_rank].clone()
            if not self.qlora_nf4:
                continue
            self.qlora_nf4_quant(mg_model, ep_rank, tp_rank, head_key, head_shards[tp_rank].clone())
def set_mtp_preprocess(self, hf_layer_idx, mtp_layer_idx, weights_dict, mg_model):
    """Move the MTP input-side weights of HF layer ``hf_layer_idx`` into ``mg_model``.

    ``enorm`` and ``hnorm`` are copied to every rank, ``eh_proj`` is chunked along
    dim 0 over the tp ranks (and quantized with ``qlora_nf4``). When ``pp_size`` is
    greater than 1 the MTP layer's own token embedding, chunked the same way, is
    stored as ``embedding.word_embeddings.weight``.
    """
    hf_prefix = f"model.layers.{hf_layer_idx}"
    mg_prefix = f"mtp.layers.{mtp_layer_idx}"

    enorm = weights_dict.pop(f"{hf_prefix}.enorm.weight")
    hnorm = weights_dict.pop(f"{hf_prefix}.hnorm.weight")
    eh_proj = weights_dict.pop(f"{hf_prefix}.eh_proj.weight")
    embedding = weights_dict.pop(f"{hf_prefix}.embed_tokens.weight")

    eh_proj_key = f"{mg_prefix}.eh_proj.weight"
    for ep_rank in range(self.ep_size):
        eh_proj_shards = torch.chunk(eh_proj, self.tp_size, dim=0)
        embedding_shards = torch.chunk(embedding, self.tp_size, dim=0)
        for tp_rank in range(self.tp_size):
            mg_model[ep_rank][tp_rank][f"{mg_prefix}.enorm.weight"] = enorm.clone()
            mg_model[ep_rank][tp_rank][f"{mg_prefix}.hnorm.weight"] = hnorm.clone()
            mg_model[ep_rank][tp_rank][eh_proj_key] = eh_proj_shards[tp_rank].clone()
            if self.qlora_nf4:
                self.qlora_nf4_quant(mg_model, ep_rank, tp_rank, eh_proj_key, eh_proj_shards[tp_rank].clone())
            if self.pp_size > 1:
                mg_model[ep_rank][tp_rank]["embedding.word_embeddings.weight"] = embedding_shards[tp_rank].clone()
def set_mtp_postprocess(self, hf_layer_idx, mtp_layer_idx, weights_dict, mg_model):
    """Copy the MTP shared-head norm of HF layer ``hf_layer_idx`` to every rank.

    The weight is removed from ``weights_dict`` and stored under
    ``mtp.final_layernorms.<mtp_layer_idx>.weight``.
    """
    source_key = f"model.layers.{hf_layer_idx}.shared_head.norm.weight"
    target_key = f"mtp.final_layernorms.{mtp_layer_idx}.weight"
    norm_weight = weights_dict.pop(source_key)
    all_ranks = [(ep, tp) for ep in range(self.ep_size) for tp in range(self.tp_size)]
    for ep_rank, tp_rank in all_ranks:
        mg_model[ep_rank][tp_rank][target_key] = norm_weight.clone()
def set_model_layer_norm(self, hf_layer_idx, local_layer_idx, weights_dict, mg_model, mtp_layer_flag=False):
    """Copy the two layer norms of one HF layer to every rank of ``mg_model``.

    ``input_layernorm`` keeps its name; ``post_attention_layernorm`` becomes
    ``pre_mlp_layernorm``. Regular layers are stored under
    ``decoder.layers.<local_layer_idx>``, MTP layers under
    ``mtp.layers.<local_layer_idx>.transformer_layer``.
    """
    hf_prefix = f"model.layers.{hf_layer_idx}"
    input_norm_weight = weights_dict.pop(f"{hf_prefix}.input_layernorm.weight")
    pre_mlp_norm_weight = weights_dict.pop(f"{hf_prefix}.post_attention_layernorm.weight")

    # Weight key of the mtp layer is different from that of the transformers layer.
    if mtp_layer_flag:
        mg_prefix = f"mtp.layers.{local_layer_idx}.transformer_layer"
    else:
        mg_prefix = f"decoder.layers.{local_layer_idx}"

    for ep_rank in range(self.ep_size):
        for tp_rank in range(self.tp_size):
            mg_model[ep_rank][tp_rank][f"{mg_prefix}.input_layernorm.weight"] = input_norm_weight.clone()
            mg_model[ep_rank][tp_rank][f"{mg_prefix}.pre_mlp_layernorm.weight"] = pre_mlp_norm_weight.clone()
# Converts the attention weights of HF layer `hf_layer` and stores them in
# mg_model[ep_rank][tp_rank] for every rank. The q/kv down-projections are concatenated
# into linear_qkv and replicated; o_proj is split along dim 1 and the up-projections
# along dim 0 across tp ranks. With self.mla_mm_split the up-projections are first split
# into separate nope/rope/value matrices. With self.qlora_nf4 the linear weights are
# additionally quantized through self.qlora_nf4_quant.
def set_model_layer_attn(self, hf_layer, local_layer_idx, weights_dict, mg_model, mtp_layer_flag=False):
"""Attention layer process"""
# Helper: target keys for the standard layout (MTP layers use a different prefix).
def _generate_attn_layers_key(mtp_flag, local_idx):
prefix = f"mtp.layers.{local_idx}.transformer_layer" if mtp_flag else \
f"decoder.layers.{local_idx}"
qkv_key = f"{prefix}.self_attention.linear_qkv.weight"
dense_key = f"{prefix}.self_attention.linear_proj.weight"
q_layernorm_key = f"{prefix}.self_attention.q_layernorm.weight"
kv_layernorm_key = f"{prefix}.self_attention.kv_layernorm.weight"
q_b_key = f"{prefix}.self_attention.linear_q_up_proj.weight"
kv_b_key = f"{prefix}.self_attention.linear_kv_up_proj.weight"
return qkv_key, dense_key, q_layernorm_key, kv_layernorm_key, q_b_key, kv_b_key
# Helper: target keys used only when the up-projections are split (mla_mm_split).
def _generate_attn_mm_split_key(mtp_flag, local_idx):
prefix = f"mtp.layers.{local_idx}.transformer_layer" if mtp_flag else \
f"decoder.layers.{local_idx}"
qk_nope_key = f"{prefix}.self_attention.linear_qk_nope.weight"
qk_rope_key = f"{prefix}.self_attention.linear_qk_rope.weight"
kv_nope_key = f"{prefix}.self_attention.linear_kv_nope.weight"
linear_v_key = f"{prefix}.self_attention.linear_v.weight"
return qk_nope_key, qk_rope_key, kv_nope_key, linear_v_key
# Down-projections: both are reshaped to (-1, hidden_size) and stacked along dim 0.
hf_q_proj = weights_dict.pop(f"model.layers.{hf_layer}.self_attn.q_a_proj.weight")
hf_kv_proj = weights_dict.pop(f"model.layers.{hf_layer}.self_attn.kv_a_proj_with_mqa.weight")
qkv_weight = torch.cat([hf_q_proj.reshape((-1, self.hidden_size)),
hf_kv_proj.reshape((-1, self.hidden_size))], dim=0)
dense_weight = weights_dict.pop(f"model.layers.{hf_layer}.self_attn.o_proj.weight")
q_layernorm = weights_dict.pop(f"model.layers.{hf_layer}.self_attn.q_a_layernorm.weight")
kv_layernorm = weights_dict.pop(f"model.layers.{hf_layer}.self_attn.kv_a_layernorm.weight")
q_b_proj = weights_dict.pop(f"model.layers.{hf_layer}.self_attn.q_b_proj.weight")
kv_b_proj = weights_dict.pop(f"model.layers.{hf_layer}.self_attn.kv_b_proj.weight")
qkv_key, dense_key, q_layernorm_key, kv_layernorm_key, q_b_key, kv_b_key = _generate_attn_layers_key(
mtp_layer_flag, local_layer_idx)
if self.mla_mm_split:
qk_nope_key, qk_rope_key, kv_nope_key, linear_v_key = _generate_attn_mm_split_key(mtp_layer_flag,
local_layer_idx)
# View the up-projections per head, split each head's rows into its two parts,
# then flatten the head dimension back so every part is a 2-D matrix.
q_b_proj = q_b_proj.reshape(self.num_attention_heads, (self.qk_head_dim + self.qk_pos_emb_head_dim), -1)
kv_b_proj = kv_b_proj.reshape(self.num_attention_heads, (self.qk_head_dim + self.v_head_dim), -1)
qk_nope, qk_rope = torch.split(q_b_proj, [self.qk_head_dim, self.qk_pos_emb_head_dim], dim=1)
kv_nope, linear_v = torch.split(kv_b_proj, [self.qk_head_dim, self.v_head_dim], dim=1)
qk_nope = qk_nope.reshape(self.num_attention_heads * self.qk_head_dim, -1)
qk_rope = qk_rope.reshape(self.num_attention_heads * self.qk_pos_emb_head_dim, -1)
kv_nope = kv_nope.reshape(self.num_attention_heads * self.qk_head_dim, -1)
linear_v = linear_v.reshape(self.num_attention_heads * self.v_head_dim, -1)
# NOTE(review): the chunking below does not depend on ep_rank; it is recomputed for every ep rank.
for ep_rank in range(self.ep_size):
dense_lst = torch.chunk(dense_weight, self.tp_size, dim=1)
if self.mla_mm_split:
qk_nope_lst = torch.chunk(qk_nope, self.tp_size, dim=0)
qk_rope_lst = torch.chunk(qk_rope, self.tp_size, dim=0)
kv_nope_lst = torch.chunk(kv_nope, self.tp_size, dim=0)
linear_v_lst = torch.chunk(linear_v, self.tp_size, dim=0)
else:
linear_qb_lst = torch.chunk(q_b_proj, self.tp_size, dim=0)
linear_kvb_lst = torch.chunk(kv_b_proj, self.tp_size, dim=0)
for tp_rank in range(self.tp_size):
# linear_qkv and both layer norms are replicated (not split) on every tp rank.
mg_model[ep_rank][tp_rank][qkv_key] = qkv_weight.clone()
mg_model[ep_rank][tp_rank][dense_key] = dense_lst[tp_rank].clone()
mg_model[ep_rank][tp_rank][q_layernorm_key] = q_layernorm.clone()
mg_model[ep_rank][tp_rank][kv_layernorm_key] = kv_layernorm.clone()
# qlora_nf4_quant overwrites the entry stored just above with the quantized data
# and adds the quant-state entries next to it.
if self.qlora_nf4:
self.qlora_nf4_quant(mg_model, ep_rank, tp_rank, qkv_key, qkv_weight.clone())
self.qlora_nf4_quant(mg_model, ep_rank, tp_rank, dense_key, dense_lst[tp_rank].clone())
if self.mla_mm_split:
mg_model[ep_rank][tp_rank][qk_nope_key] = qk_nope_lst[tp_rank].clone()
mg_model[ep_rank][tp_rank][qk_rope_key] = qk_rope_lst[tp_rank].clone()
mg_model[ep_rank][tp_rank][kv_nope_key] = kv_nope_lst[tp_rank].clone()
mg_model[ep_rank][tp_rank][linear_v_key] = linear_v_lst[tp_rank].clone()
if self.qlora_nf4:
self.qlora_nf4_quant(mg_model, ep_rank, tp_rank, qk_nope_key, qk_nope_lst[tp_rank].clone())
self.qlora_nf4_quant(mg_model, ep_rank, tp_rank, qk_rope_key, qk_rope_lst[tp_rank].clone())
self.qlora_nf4_quant(mg_model, ep_rank, tp_rank, kv_nope_key, kv_nope_lst[tp_rank].clone())
self.qlora_nf4_quant(mg_model, ep_rank, tp_rank, linear_v_key, linear_v_lst[tp_rank].clone())
else:
mg_model[ep_rank][tp_rank][q_b_key] = linear_qb_lst[tp_rank].clone()
mg_model[ep_rank][tp_rank][kv_b_key] = linear_kvb_lst[tp_rank].clone()
if self.qlora_nf4:
self.qlora_nf4_quant(mg_model, ep_rank, tp_rank, q_b_key, linear_qb_lst[tp_rank].clone())
self.qlora_nf4_quant(mg_model, ep_rank, tp_rank, kv_b_key, linear_kvb_lst[tp_rank].clone())
# Converts the MLP weights of HF layer `hf_layer_idx` and stores them in
# mg_model[ep_rank][tp_rank]. Layers with hf_layer_idx < self.first_k_dense_replace are
# dense MLPs; all other layers (including MTP layers) are MoE layers with a router,
# shared experts and self.num_experts routed experts. Routed experts are written either
# as grouped-GEMM tensors (weight1/weight2) or as one linear_fc1/linear_fc2 pair per
# local expert, depending on self.moe_grouped_gemm.
def set_model_layer_mlp(self, hf_layer_idx, local_layer_idx, weights_dict, mg_model, mtp_layer_flag=False):
"""MLP layer process"""
# Helper: target keys of the MoE layer.
# NOTE(review): the non-MTP prefix reads `local_layer_idx` from the enclosing scope rather
# than the `local_idx` parameter; the only call below passes the same value for both.
def _generate_moe_layer_key(local_idx, mtp_flag):
prefix = f"mtp.layers.{local_idx}.transformer_layer" if mtp_flag else f"decoder.layers.{local_layer_idx}"
router_key = f"{prefix}.mlp.router.weight"
router_bias_key = f"{prefix}.mlp.router.expert_bias"
shared_fc1_key = f"{prefix}.mlp.shared_experts.linear_fc1.weight"
shared_fc2_key = f"{prefix}.mlp.shared_experts.linear_fc2.weight"
experts_weight1_key = f"{prefix}.mlp.experts.weight1"
experts_weight2_key = f"{prefix}.mlp.experts.weight2"
return router_key, router_bias_key, shared_fc1_key, shared_fc2_key, experts_weight1_key, experts_weight2_key
if hf_layer_idx < self.first_k_dense_replace:
# dense layer
gate_proj = weights_dict.pop(f"model.layers.{hf_layer_idx}.mlp.gate_proj.weight")
up_proj = weights_dict.pop(f"model.layers.{hf_layer_idx}.mlp.up_proj.weight")
linear_fc1_weight = torch.cat([gate_proj, up_proj], dim=0)
linear_fc2_weight = weights_dict.pop(f"model.layers.{hf_layer_idx}.mlp.down_proj.weight")
for ep_rank in range(self.ep_size):
# gate and up are chunked separately and re-joined per tp rank, so each rank's fc1
# holds its gate slice followed by its up slice; fc2 is split along dim 1.
gate, up = torch.chunk(linear_fc1_weight, 2, dim=0)
mlp_l0_weight_W = torch.chunk(gate, self.tp_size, dim=0)
mlp_l0_weight_V = torch.chunk(up, self.tp_size, dim=0)
mlp_l0_weight = [torch.cat(weights, dim=0) for weights in zip(mlp_l0_weight_W, mlp_l0_weight_V)]
mlp_l1_weight = torch.chunk(linear_fc2_weight, self.tp_size, dim=1)
for tp_rank in range(self.tp_size):
mg_model[ep_rank][tp_rank][f"decoder.layers.{local_layer_idx}.mlp.linear_fc1.weight"] = \
mlp_l0_weight[tp_rank].clone()
mg_model[ep_rank][tp_rank][f"decoder.layers.{local_layer_idx}.mlp.linear_fc2.weight"] = \
mlp_l1_weight[tp_rank].clone()
if self.qlora_nf4:
self.qlora_nf4_quant(mg_model, ep_rank, tp_rank,
f"decoder.layers.{local_layer_idx}.mlp.linear_fc1.weight",
mlp_l0_weight[tp_rank].clone())
self.qlora_nf4_quant(mg_model, ep_rank, tp_rank,
f"decoder.layers.{local_layer_idx}.mlp.linear_fc2.weight",
mlp_l1_weight[tp_rank].clone())
else:
# moe layer & mtp layer
# Router weight and bias are truncated to the first self.num_experts entries.
mlp_router_weight = weights_dict.pop(f"model.layers.{hf_layer_idx}.mlp.gate.weight")
mlp_router_weight = mlp_router_weight[:self.num_experts, :]
mlp_router_bias = weights_dict.pop(f"model.layers.{hf_layer_idx}.mlp.gate.e_score_correction_bias")
mlp_router_bias = mlp_router_bias[:self.num_experts]
shared_gate_proj = weights_dict.pop(f"model.layers.{hf_layer_idx}.mlp.shared_experts.gate_proj.weight")
shared_up_proj = weights_dict.pop(f"model.layers.{hf_layer_idx}.mlp.shared_experts.up_proj.weight")
shared_fc2_weight = weights_dict.pop(f"model.layers.{hf_layer_idx}.mlp.shared_experts.down_proj.weight")
experts_linear_fc1_list = []
experts_linear_fc2_list = []
for expert_idx in range(self.num_experts):
# NOTE(review): the four shared_* lines below do not depend on expert_idx; they are
# recomputed for every expert and only the last result is used after the loop.
shared_l0_W = torch.chunk(shared_gate_proj, self.tp_size, dim=0)
shared_l0_V = torch.chunk(shared_up_proj, self.tp_size, dim=0)
shared_l0_lst = [torch.cat(weights, dim=0) for weights in zip(shared_l0_W, shared_l0_V)]
shared_l1_lst = torch.chunk(shared_fc2_weight, self.tp_size, dim=1)
gate_proj = weights_dict.pop(f"model.layers.{hf_layer_idx}.mlp.experts.{expert_idx}.gate_proj.weight")
up_proj = weights_dict.pop(f"model.layers.{hf_layer_idx}.mlp.experts.{expert_idx}.up_proj.weight")
# With moe_tp_extend_ep the expert weights are not interleaved per tp rank (chunk count 1).
expert_tp_size = self.tp_size
if self.moe_tp_extend_ep:
expert_tp_size = 1
gate_w_list = torch.chunk(gate_proj, expert_tp_size, dim=0)
up_w_list = torch.chunk(up_proj, expert_tp_size, dim=0)
fc1_weight = torch.cat([torch.cat(weights, dim=0) for weights in zip(gate_w_list, up_w_list)], dim=0)
fc2_weight = weights_dict.pop(f"model.layers.{hf_layer_idx}.mlp.experts.{expert_idx}.down_proj.weight")
# Expert weights are collected transposed; the non-grouped path transposes them back.
experts_linear_fc1_list.append(fc1_weight.t())
experts_linear_fc2_list.append(fc2_weight.t())
# generate weights key
router_key, router_bias_key, shared_fc1_key, shared_fc2_key, experts_weight1_key, experts_weight2_key = _generate_moe_layer_key(
local_layer_idx, mtp_layer_flag)
# Router and shared experts: router replicated, shared experts split across tp ranks.
for ep_rank in range(self.ep_size):
for tp_rank in range(self.tp_size):
mg_model[ep_rank][tp_rank][router_key] = mlp_router_weight.clone()
mg_model[ep_rank][tp_rank][router_bias_key] = mlp_router_bias.clone()
mg_model[ep_rank][tp_rank][shared_fc1_key] = shared_l0_lst[tp_rank].clone()
mg_model[ep_rank][tp_rank][shared_fc2_key] = shared_l1_lst[tp_rank].clone()
if self.qlora_nf4:
self.qlora_nf4_quant(mg_model, ep_rank, tp_rank, shared_fc1_key, shared_l0_lst[tp_rank].clone())
self.qlora_nf4_quant(mg_model, ep_rank, tp_rank, shared_fc2_key, shared_l1_lst[tp_rank].clone())
if self.moe_grouped_gemm:
# Grouped GEMM: all experts are concatenated, viewed per expert and then chunked over
# ep ranks (over ep_size * tp_size groups when moe_tp_extend_ep is set).
gemm_fc1 = torch.cat(experts_linear_fc1_list).view(self.hidden_size, -1)
gemm_fc2 = torch.cat(experts_linear_fc2_list).view(-1, self.hidden_size)
if self.moe_tp_extend_ep:
gemm_fc1_ep = torch.chunk(gemm_fc1.view(self.num_experts, self.hidden_size, -1),
self.ep_size * self.tp_size, dim=0)
gemm_fc2_ep = torch.chunk(gemm_fc2.view(self.num_experts, -1, self.hidden_size),
self.ep_size * self.tp_size, dim=0)
else:
gemm_fc1_ep = torch.chunk(gemm_fc1.view(self.num_experts, self.hidden_size, -1), self.ep_size,
dim=0)
gemm_fc2_ep = torch.chunk(gemm_fc2.view(self.num_experts, -1, self.hidden_size), self.ep_size,
dim=0)
for ep_rank in range(self.ep_size):
if not self.moe_tp_extend_ep:
# Regular tensor parallelism inside each expert group: fc1 split on its last dim, fc2 on dim 1.
gemm_fc1_ep_tp = torch.chunk(gemm_fc1_ep[ep_rank], self.tp_size, dim=2)
gemm_fc2_ep_tp = torch.chunk(gemm_fc2_ep[ep_rank], self.tp_size, dim=1)
for tp_rank in range(self.tp_size):
if self.moe_tp_extend_ep:
# Each (ep_rank, tp_rank) pair owns the whole expert group number ep_rank * tp_size + tp_rank.
mg_model[ep_rank][tp_rank][experts_weight1_key] = gemm_fc1_ep[
ep_rank * self.tp_size + tp_rank].reshape(self.hidden_size, -1).clone()
mg_model[ep_rank][tp_rank][experts_weight2_key] = gemm_fc2_ep[
ep_rank * self.tp_size + tp_rank].reshape(-1, self.hidden_size).clone()
if self.qlora_nf4:
self.qlora_nf4_quant(mg_model, ep_rank, tp_rank, experts_weight1_key,
gemm_fc1_ep[ep_rank * self.tp_size + tp_rank].reshape(self.hidden_size, -1).clone())
self.qlora_nf4_quant(mg_model, ep_rank, tp_rank, experts_weight2_key,
gemm_fc2_ep[ep_rank * self.tp_size + tp_rank].reshape(-1, self.hidden_size).clone())
else:
mg_model[ep_rank][tp_rank][experts_weight1_key] = gemm_fc1_ep_tp[tp_rank].reshape(
self.hidden_size, -1).clone()
mg_model[ep_rank][tp_rank][experts_weight2_key] = gemm_fc2_ep_tp[tp_rank].reshape(
-1, self.hidden_size).clone()
if self.qlora_nf4:
self.qlora_nf4_quant(mg_model, ep_rank, tp_rank, experts_weight1_key,
gemm_fc1_ep_tp[tp_rank].reshape(self.hidden_size, -1).clone())
self.qlora_nf4_quant(mg_model, ep_rank, tp_rank, experts_weight2_key,
gemm_fc2_ep_tp[tp_rank].reshape(-1, self.hidden_size).clone())
else:
# Per-expert layout: every ep rank owns num_experts // ep_size consecutive experts,
# each stored under its local index and split across tp ranks.
num_local_experts = self.num_experts // self.ep_size
for ep_rank in range(self.ep_size):
for local_experts_idx in range(num_local_experts):
local_prefix = f"decoder.layers.{local_layer_idx}.mlp.experts.local_experts"
local_fc1_key = f"{local_prefix}.{local_experts_idx}.linear_fc1.weight"
local_fc2_key = f"{local_prefix}.{local_experts_idx}.linear_fc2.weight"
if mtp_layer_flag:
local_prefix = f"mtp.layers.{local_layer_idx}.transformer_layer.mlp.experts.local_experts"
local_fc1_key = f"{local_prefix}.{local_experts_idx}.linear_fc1.weight"
local_fc2_key = f"{local_prefix}.{local_experts_idx}.linear_fc2.weight"
global_experts_idx = local_experts_idx + ep_rank * num_local_experts
# Undo the transpose applied when the expert weights were collected above.
local_fc1_weight = experts_linear_fc1_list[global_experts_idx].t()
local_fc2_weight = experts_linear_fc2_list[global_experts_idx].t()
local_fc1_lst = torch.chunk(local_fc1_weight, self.tp_size, dim=0)
local_fc2_lst = torch.chunk(local_fc2_weight, self.tp_size, dim=1)
for tp_rank in range(self.tp_size):
mg_model[ep_rank][tp_rank][local_fc1_key] = local_fc1_lst[tp_rank].clone()
mg_model[ep_rank][tp_rank][local_fc2_key] = local_fc2_lst[tp_rank].clone()
if self.qlora_nf4:
self.qlora_nf4_quant(mg_model, ep_rank, tp_rank, local_fc1_key,
local_fc1_lst[tp_rank].clone())
self.qlora_nf4_quant(mg_model, ep_rank, tp_rank, local_fc2_key,
local_fc2_lst[tp_rank].clone())
def generate_pp_local_layer_idx(self):
    """Build the local layer indices held by every pipeline-parallel rank.

    Each rank starts with the contiguous indices ``0..n-1``. ``n`` is the
    rank's entry in ``self.num_layer_list`` (a comma-separated string) when
    that attribute is set, otherwise ``self.num_layers // self.pp_size``.
    Every global layer number listed in ``self.noop_layers`` (also a
    comma-separated string) is then removed from the rank it maps to.

    Returns:
        Mapping ``pp_rank -> list[int]`` with the remaining local indices.

    Raises:
        ValueError: if an entry of either comma-separated string is not an
            integer, or a noop layer's local index is not present in its rank.
        IndexError: if ``self.num_layer_list`` has fewer entries than
            ``self.pp_size``.
    """
    pp_local_layer_idx = defaultdict()

    if self.num_layer_list is not None:
        # Parse the per-rank layer counts once instead of once per rank.
        layers_per_rank = [int(count) for count in self.num_layer_list.split(',')]
        for pp_rank in range(self.pp_size):
            pp_local_layer_idx[pp_rank] = list(range(layers_per_rank[pp_rank]))
    else:
        for pp_rank in range(self.pp_size):
            pp_local_layer_idx[pp_rank] = list(range(self.num_layers // self.pp_size))

    if self.noop_layers is not None:
        noop_list = [int(layer) for layer in self.noop_layers.split(",")]
        # Noop layers are located with the uniform split, exactly as before.
        num_layers_each_pp = self.num_layers // self.pp_size
        for noop_layer in noop_list:
            pp_idx, local_noop_idx = divmod(noop_layer, num_layers_each_pp)
            pp_local_layer_idx[pp_idx].remove(local_noop_idx)

    return pp_local_layer_idx
def generate_vpp_local_layer_idx(self):
    """Build the local layer indices of every (pp_rank, vpp_rank) model chunk.

    Every chunk starts with the indices ``0..self.vpp_stage - 1``. Each
    global layer number in the comma-separated ``self.noop_layers`` string
    is then removed from the chunk it maps to; the mapping depends on
    whether ``self.dualpipe`` is set.

    Returns:
        Nested mapping ``pp_rank -> vpp_rank -> list[int]``.
    """
    chunk_layer_idx = defaultdict()
    for stage in range(self.pp_size):
        per_chunk = defaultdict()
        for chunk in range(self.vpp_size):
            per_chunk[chunk] = list(range(self.vpp_stage))
        chunk_layer_idx[stage] = per_chunk

    if self.noop_layers is None:
        return chunk_layer_idx

    noop_ids = [int(item) for item in self.noop_layers.split(",")]
    layers_per_stage = self.num_layers // self.pp_size

    if self.dualpipe:
        # e.g. pp2 noop5 [0 1 6 7 | 2 3 4 5] -> layer5: pp1 vpp1 local_idx1
        # A layer in the second half is mirrored onto the first half, so it
        # shares the pp rank of its mirror layer but lives in vpp rank 1,
        # with its local position counted from the end of the chunk.
        half_layers = self.num_layers // 2
        for layer in noop_ids:
            in_second_half = layer >= half_layers
            folded = self.num_layers - 1 - layer if in_second_half else layer
            stage = folded // (half_layers // self.pp_size)
            offset = folded - stage * self.vpp_stage
            if in_second_half:
                chunk, local = 1, self.vpp_stage - 1 - offset
            else:
                chunk, local = 0, offset
            chunk_layer_idx[stage][chunk].remove(local)
    else:
        for layer in noop_ids:
            stage = layer % (self.pp_size * self.vpp_stage) // self.vpp_stage
            chunk = layer // self.vpp_stage // self.pp_size
            local = layer % layers_per_stage % self.vpp_stage
            chunk_layer_idx[stage][chunk].remove(local)

    return chunk_layer_idx
def run(self):
    """Convert the loaded HF weights and save a Megatron-format checkpoint.

    Without virtual pipeline stages (``self.vpp_stage is None``) one state
    dict per (ep_rank, tp_rank) is built for every pipeline rank and saved
    under the ``"model"`` key. With virtual pipeline stages one state dict
    per vpp rank is built and saved under ``"model0"``, ``"model1"``, ...
    in the same file.

    Pre-process weights go to the first rank/chunk. Post-process weights and
    MTP layers go to the last pipeline rank (last vpp chunk), or, when
    ``self.dualpipe`` is set, to the last vpp chunk of pipeline rank 0.
    """
    pp_local_layer_idx = self.generate_pp_local_layer_idx()
    save_model_path = self.mg_path_process(self.mg_save_path)
    last_pp_rank = self.pp_size - 1

    if self.vpp_stage is None:
        for pp_rank in range(self.pp_size):
            mg_model = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
            pp_weights = self.load_matched_hf_weights(pp_rank)
            if pp_rank == 0:
                self.set_model_preprocess(pp_weights, mg_model)

            layer_list = self.pprank_layer_idxs[pp_rank]
            if self.mtp_num_layers and pp_rank == last_pp_rank:
                self._convert_mtp_layers(layer_list, pp_weights, mg_model)
            self._convert_decoder_layers(layer_list, pp_local_layer_idx[pp_rank], pp_weights, mg_model)

            if pp_rank == last_pp_rank:
                self.set_model_postprocess(pp_weights, mg_model)

            for ep_rank in range(self.ep_size):
                for tp_rank in range(self.tp_size):
                    model_dict = {"model": mg_model[ep_rank][tp_rank], "checkpoint_version": 3.0, "iteration": 1}
                    self._save_rank_checkpoint(save_model_path, pp_rank, ep_rank, tp_rank, model_dict)
    else:
        vpp_local_layer_idx = self.generate_vpp_local_layer_idx()
        last_vpp_rank = self.vpp_size - 1
        for pp_rank in range(self.pp_size):
            mg_model = defaultdict()
            for vpp_rank in range(self.vpp_size):
                pp_weights = self.load_matched_hf_weights(pp_rank, vpp_rank)
                mg_model[vpp_rank] = defaultdict(lambda: defaultdict(lambda: defaultdict(dict)))
                chunk_model = mg_model[vpp_rank]
                vpp_list = self.vpprank_layer_idxs[pp_rank][vpp_rank]

                # Chunk that receives the post-process weights and MTP layers.
                dualpipe_tail = self.dualpipe and pp_rank == 0 and vpp_rank == last_vpp_rank
                regular_tail = not self.dualpipe and pp_rank == last_pp_rank and vpp_rank == last_vpp_rank

                if pp_rank == 0 and vpp_rank == 0:
                    self.set_model_preprocess(pp_weights, chunk_model)
                if dualpipe_tail:
                    self.set_model_postprocess(pp_weights, chunk_model)
                if self.mtp_num_layers and (dualpipe_tail or regular_tail):
                    self._convert_mtp_layers(vpp_list, pp_weights, chunk_model)

                self._convert_decoder_layers(
                    vpp_list, vpp_local_layer_idx[pp_rank][vpp_rank], pp_weights, chunk_model)

                if regular_tail:
                    self.set_model_postprocess(pp_weights, chunk_model)

            for ep_rank in range(self.ep_size):
                for tp_rank in range(self.tp_size):
                    model_dict = {"checkpoint_version": 3.0, "iteration": 1}
                    for vpp_rank in range(self.vpp_size):
                        model_dict[f"model{vpp_rank}"] = mg_model[vpp_rank][ep_rank][tp_rank]
                    self._save_rank_checkpoint(save_model_path, pp_rank, ep_rank, tp_rank, model_dict)
    logger.info("Done!")


def _convert_mtp_layers(self, layer_list, pp_weights, mg_model):
    """Split the MTP layers off ``layer_list`` and convert them into ``mg_model``.

    The ``self.mtp_num_layers`` largest HF layer indices are taken as the
    MTP layers. NOTE: ``layer_list`` is sorted and shortened in place, so
    the caller afterwards iterates only the remaining decoder layers.
    """
    layer_list.sort()
    mtp_layer_list = [layer_list.pop() for _ in range(self.mtp_num_layers)]
    for local_mtp_idx, mtp_layer in enumerate(mtp_layer_list):
        logger.info(f"Converting the weights of mtp layer {mtp_layer}.")
        self.set_mtp_preprocess(mtp_layer, local_mtp_idx, pp_weights, mg_model)
        self.set_model_layer_norm(mtp_layer, local_mtp_idx, pp_weights, mg_model, mtp_layer_flag=True)
        self.set_model_layer_attn(mtp_layer, local_mtp_idx, pp_weights, mg_model, mtp_layer_flag=True)
        self.set_model_layer_mlp(mtp_layer, local_mtp_idx, pp_weights, mg_model, mtp_layer_flag=True)
        self.set_mtp_postprocess(mtp_layer, local_mtp_idx, pp_weights, mg_model)


def _convert_decoder_layers(self, hf_layers, local_indices, pp_weights, mg_model):
    """Convert each HF layer in ``hf_layers`` to its local index in ``local_indices``.

    The n-th HF layer is written under the n-th entry of ``local_indices``
    (noop layers have already been removed from that list).
    """
    for position, hf_layer in enumerate(hf_layers):
        logger.info(f"Converting the weights of layer {hf_layer}.")
        local_layer_idx = local_indices[position]
        self.set_model_layer_norm(hf_layer, local_layer_idx, pp_weights, mg_model)
        self.set_model_layer_attn(hf_layer, local_layer_idx, pp_weights, mg_model)
        self.set_model_layer_mlp(hf_layer, local_layer_idx, pp_weights, mg_model)


def _save_rank_checkpoint(self, save_model_path, pp_rank, ep_rank, tp_rank, model_dict):
    """Write ``model_dict`` as ``model_optim_rng.pt`` in the rank's directory."""
    save_prefix = self.generate_mg_weights_dir(tp_rank=tp_rank, pp_rank=pp_rank, ep_rank=ep_rank)
    parallel_save_path = os.path.join(save_model_path, save_prefix)
    # exist_ok=True on both code paths: previously only the VPP path
    # tolerated an already existing rank directory.
    os.makedirs(parallel_save_path, exist_ok=True)
    save_file_name = os.path.join(parallel_save_path, "model_optim_rng.pt")
    logger.info(f"Saving to {save_file_name}")
    torch.save(model_dict, save_file_name, pickle_protocol=4, _use_new_zipfile_serialization=True)
def get_args():
    """Parse the converter's command-line options.

    Options are registered from a table in their original order.
    Unrecognised command-line arguments are ignored rather than rejected.

    Returns:
        The ``argparse.Namespace`` holding the recognised options.
    """
    option_table = (
        (('--load-dir',),
         dict(type=str, required=True, help='Directory to load model checkpoint from')),
        (('--save-dir',),
         dict(type=str, required=True, help='Directory to save model checkpoint to')),
        (('--target-tensor-parallel-size',),
         dict(type=int, default=1, help='Target tensor model parallel size, defaults to 1.')),
        (('--target-pipeline-parallel-size',),
         dict(type=int, default=1, help='Target pipeline model parallel size, defaults to 1.')),
        (('--target-expert-parallel-size',),
         dict(type=int, default=1, help='Target expert model parallel size, defaults to 1.')),
        (('--num-layers-per-virtual-pipeline-stage',),
         dict(type=int, default=None, help='Number of layers per virtual pipeline stage')),
        (('--moe-grouped-gemm',),
         dict(action='store_true', help='Use moe grouped gemm.')),
        (("--noop-layers",),
         dict(type=str, default=None, help='Specity the noop layers.')),
        (('--mtp-num-layers',),
         dict(type=int, default=0, help='Multi-Token prediction layer num')),
        (('--num-layer-list',),
         dict(type=str, help='a list of number of layers, separated by comma; e.g., 4,4,4,4')),
        (('--num-layers',),
         dict(type=int, default=61, help='Number of transformer layers.')),
        (('--first-k-dense-replace',),
         dict(type=int, default=3, help='Customizing the number of dense layers.')),
        (("--moe-tp-extend-ep",),
         dict(action='store_true',
              help="use tp group to extend experts parallism instead of sharding weight tensor of experts in tp group")),
        (('--mla-mm-split',),
         dict(action='store_true', default=False, help='Split 2 up-proj matmul into 4 in MLA')),
        (('--schedules-method',),
         dict(type=str, default=None, choices=['dualpipev'],
              help='An innovative bidirectional pipeline parallelism algorithm.')),
        (('--qlora-nf4',),
         dict(action='store_true', help='use bitsandbytes nf4 to quantize model.')),
    )

    parser = argparse.ArgumentParser()
    for flags, settings in option_table:
        parser.add_argument(*flags, **settings)

    parsed, _unknown = parser.parse_known_args()
    return parsed
def main():
    """Parse the command line, build the checkpoint converter and run it."""
    args = get_args()
    logger.info(f"Arguments: {args}")

    # Map the command-line option names onto the converter's keyword names.
    converter_kwargs = {
        "hf_model_path": args.load_dir,
        "mg_save_path": args.save_dir,
        "num_layers": args.num_layers,
        "tp_size": args.target_tensor_parallel_size,
        "pp_size": args.target_pipeline_parallel_size,
        "ep_size": args.target_expert_parallel_size,
        "num_dense_layers": args.first_k_dense_replace,
        "num_layer_list": args.num_layer_list,
        "noop_layers": args.noop_layers,
        "moe_grouped_gemm": args.moe_grouped_gemm,
        "moe_tp_extend_ep": args.moe_tp_extend_ep,
        "mla_mm_split": args.mla_mm_split,
        "dualpipe": args.schedules_method,
        "mtp_num_layers": args.mtp_num_layers,
        "qlora_nf4": args.qlora_nf4,
        "vpp_stage": args.num_layers_per_virtual_pipeline_stage,
    }
    CkptConvert(**converter_kwargs).run()


# Run the conversion only when executed as a script, not when imported.
if __name__ == '__main__':
    main()