@@ -912,8 +912,9 @@ def shard_low_precision_checkpoint(
912912 """
913913 assert tp_grain_size % 8 == 0 , "tp_grain_size must be a multiple of 8"
914914 num_heads = model_config ["num_attention_heads" ]
915+ num_kv_heads = num_heads
915916 if "num_key_value_heads" in model_config :
916- num_heads = model_config ["num_key_value_heads" ]
917+ num_kv_heads = model_config ["num_key_value_heads" ]
917918 local_rank = rank
918919
919920 mha_layers_split_by_N = [
@@ -923,6 +924,9 @@ def shard_low_precision_checkpoint(
923924 "q_b_proj" ,
924925 "kv_b_proj" ,
925926 ]
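+    # fused qkv projection (one tensor holding q, k and v, e.g. in
+    # Phi-3-style checkpoints) is sharded per head along N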
+    qkv_proj_layers = [
+        "qkv_proj",
+    ]
     # mlp is split with grain size = tp_grain_size
     mlp_layers_split_by_N = [
         "gate_proj",
@@ -933,6 +937,9 @@ def shard_low_precision_checkpoint(
933937 "w1" ,
934938 "w3" ,
935939 ]
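+    # fused gate/up projection (one tensor holding both MLP branches) is
+    # sharded in grains of tp_grain_size along N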
+    gate_up_proj_layers = [
+        "gate_up_proj",
+    ]
     mha_layers_split_by_K = [
         "o_proj",
         "out_proj",
@@ -947,20 +954,28 @@ def shard_low_precision_checkpoint(
947954 "w2" ,
948955 ]
949956 lm_head_layers = ["lm_head" ] # split by K but not quantized
957+
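+    # Match whole dot-separated name segments instead of raw substrings, so
+    # that a fused name like "gate_up_proj" does not also hit "up_proj".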
+    def _key_belongs_to(key, layer_group):
+        key_split = key.split(".")
+        for layer in layer_group:
+            if layer in key_split:
+                return True
+        return False
+
     low_precision_checkpoint_dict = low_precision_checkpoint.copy()
     head_range = [0]
-    head_per_rank = num_heads // world_size
+    head_per_rank = num_kv_heads // world_size
     for i in range(0, world_size):
         head_this_rank = head_per_rank
-        if i < num_heads % world_size:
+        if i < num_kv_heads % world_size:
             head_this_rank += 1
         head_range.append(head_range[-1] + head_this_rank)
     for key in low_precision_checkpoint.keys():
         q_head_start = head_range[rank]
         q_head_end = q_head_start + (head_range[rank + 1] - head_range[rank])
         if "bias" in key:
             continue
-        if any(substring in key for substring in mha_layers_split_by_N):
+        if _key_belongs_to(key, mha_layers_split_by_N):
             data = low_precision_checkpoint_dict[key]
             if quantization_method == "awq":
                 # qweight shape: [K, N // 8]
@@ -1041,7 +1056,91 @@ def shard_low_precision_checkpoint(
                 ].contiguous()
             else:
                 raise AssertionError(f"{quantization_method} is not supported yet.")
-        elif any(substring in key for substring in mlp_layers_split_by_N):
+        elif _key_belongs_to(key, qkv_proj_layers):
+            # split the fused q, k and v projections, shard each separately,
+            # then concat them back together
+            # mha layer split by N
+            data = low_precision_checkpoint_dict[key]
+            hidden_size = model_config["hidden_size"]
+            head_dim = hidden_size // num_heads
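+            # head_dim is derived from the full attention head count, not
+            # the kv head count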
+            if quantization_method == "awq":
+                # qweight shape: [K, N // 8]
+                # scales shape: [K // G, N]
+                # qzeros shape: [K // G, N // 8]
+                N_pack_factor = 1 if "scales" in key else 8
+                N = data.shape[-1] * N_pack_factor
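+                # fused layout along N: [q | k | v], with k and v each
+                # spanning num_kv_heads * head_dim columns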
+                q_pos = N - 2 * num_kv_heads * head_dim
+                k_pos = q_pos + num_kv_heads * head_dim
+                v_pos = k_pos + num_kv_heads * head_dim
+                q_pos //= N_pack_factor
+                k_pos //= N_pack_factor
+                v_pos //= N_pack_factor
+                data_list = [
+                    data[:, :q_pos],
+                    data[:, q_pos:k_pos],
+                    data[:, k_pos:v_pos],
+                ]
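+                # shard each of q/k/v by head if N divides evenly over
+                # head_range[-1] heads, otherwise fall back to an even
+                # world_size split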
+                for i in range(len(data_list)):
+                    data = data_list[i].contiguous()
+                    if data.shape[-1] % head_range[-1] == 0:
+                        dim = data.shape[-1] // head_range[-1]
+                    else:
+                        assert data.shape[-1] % world_size == 0
+                        dim = data.shape[-1] // world_size
+                    q_head_start = local_rank
+                    q_head_end = local_rank + 1
+                    data_list[i] = data[
+                        :, q_head_start * dim : q_head_end * dim
+                    ].contiguous()
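+                # re-concatenate the per-rank q/k/v shards into one fused tensor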
+                low_precision_checkpoint_dict[key] = torch.cat(
+                    data_list, dim=-1
+                ).contiguous()
+            elif quantization_method == "gptq" or (
+                quantization_method == "rtn" and bits == 4
+            ):
+                # qweight shape: [K // 8, N]
+                # scales shape: [K // G, N]
+                # qzeros shape: [K // G, N // 8]
+                # g_idx shape: [K]
+                data_list = []
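+                # g_idx has no N dimension, so only the other tensors are split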
1106+ if "g_idx" not in key :
1107+ N_pack_factor = 8 if "qzeros" in key else 1
1108+ N = data .shape [- 1 ] * N_pack_factor
1109+ q_pos = N - 2 * num_kv_heads * head_dim
1110+ k_pos = q_pos + num_kv_heads * head_dim
1111+ v_pos = k_pos + num_kv_heads * head_dim
1112+ q_pos //= N_pack_factor
1113+ k_pos //= N_pack_factor
1114+ v_pos //= N_pack_factor
1115+ data_list = [
1116+ data [:, :q_pos ],
1117+ data [:, q_pos :k_pos ],
1118+ data [:, k_pos :v_pos ],
1119+ ]
1120+ for i in range (len (data_list )):
1121+ if "g_idx" in key :
1122+ continue
1123+ data = data_list [i ]
1124+ if data .shape [- 1 ] % head_range [- 1 ] == 0 :
1125+ dim = data .shape [- 1 ] // head_range [- 1 ]
1126+ else :
1127+ assert data .shape [- 1 ] % world_size == 0
1128+ dim = data .shape [- 1 ] // world_size
1129+ q_head_start = local_rank
1130+ q_head_end = local_rank + 1
1131+ data_list [i ] = data [
1132+ :, q_head_start * dim : q_head_end * dim
1133+ ].contiguous ()
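+                # drop g_idx when act-order is disabled; otherwise keep it as is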
1134+ if "g_idx" in key :
1135+ if not desc_act :
1136+ low_precision_checkpoint_dict .pop (key )
1137+ else :
1138+ low_precision_checkpoint_dict [key ] = torch .cat (
1139+ data_list , dim = - 1
1140+ ).contiguous ()
1141+ else :
1142+ raise AssertionError (f"{ quantization_method } is not supported yet." )
1143+ elif _key_belongs_to (key , mlp_layers_split_by_N ):
10451144 data = low_precision_checkpoint_dict [key ]
10461145 if quantization_method == "awq" :
10471146 # qweight shape: [K, N // 8]
@@ -1178,7 +1277,95 @@ def shard_low_precision_checkpoint(
                 ].contiguous()
             else:
                 raise AssertionError(f"{quantization_method} is not supported yet.")
-        elif any(substring in key for substring in mha_layers_split_by_K):
+        elif _key_belongs_to(key, gate_up_proj_layers):
+            # split the fused gate and up projections, shard each separately,
+            # then concat them back together
+            # mlp layer split by N
+            data = low_precision_checkpoint_dict[key]
+            if quantization_method == "awq":
+                # qweight shape: [K, N // 8]
+                # scales shape: [K // G, N]
+                # qzeros shape: [K // G, N // 8]
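+                # gate and up occupy the two equal halves of N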
+                data_list = list(data.chunk(2, dim=-1))
+                for i in range(len(data_list)):
+                    data = data_list[i].contiguous()
+                    if "scales" in key:
+                        assert (
+                            data.shape[1] % tp_grain_size == 0
+                        ), "N must be divisible by tp_grain_size"
+                        grains = data.shape[1] // tp_grain_size
+                        dim = tp_grain_size
+                    else:
+                        assert (
+                            data.shape[1] * 8
+                        ) % tp_grain_size == 0, "N must be divisible by tp_grain_size"
+                        grains = data.shape[1] // (tp_grain_size // 8)
+                        dim = tp_grain_size // 8
+                    grains_per_rank = grains // world_size
+                    grains_rem = grains % world_size
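+                    # lower ranks each take one extra grain until the
+                    # remainder is exhausted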
+                    grains_start = grains_per_rank * local_rank + min(
+                        local_rank, grains_rem
+                    )
+                    grains_end = (
+                        grains_start
+                        + grains_per_rank
+                        + (1 if local_rank < grains_rem else 0)
+                    )
+                    data_list[i] = data[
+                        :, grains_start * dim : grains_end * dim
+                    ].contiguous()
+                low_precision_checkpoint_dict[key] = torch.cat(
+                    data_list, dim=-1
+                ).contiguous()
+            elif quantization_method == "gptq" or (
+                quantization_method == "rtn" and bits == 4
+            ):
+                # qweight shape: [K // 8, N]
+                # scales shape: [K // G, N]
+                # qzeros shape: [K // G, N // 8]
+                # g_idx shape: [K]
+                data_list = list(data.chunk(2, dim=-1))
+                for i in range(len(data_list)):
+                    if "g_idx" in key:
+                        continue
+                    data = data_list[i]
+                    if "qzeros" in key:
+                        assert (
+                            data.shape[-1] * 8
+                        ) % tp_grain_size == 0, "N must be divisible by tp_grain_size"
+                        grains = data.shape[-1] // (tp_grain_size // 8)
+                        dim = tp_grain_size // 8
+                    elif "g_idx" not in key:  # qweight, scales
+                        assert (
+                            data.shape[-1] % tp_grain_size == 0
+                        ), "N must be divisible by tp_grain_size"
+                        grains = data.shape[-1] // tp_grain_size
+                        dim = tp_grain_size
+                    grains_per_rank = grains // world_size
+                    grains_rem = grains % world_size
+                    grains_start = grains_per_rank * local_rank + min(
+                        local_rank, grains_rem
+                    )
+                    grains_end = (
+                        grains_start
+                        + grains_per_rank
+                        + (1 if local_rank < grains_rem else 0)
+                    )
+                    data_list[i] = data[
+                        :, grains_start * dim : grains_end * dim
+                    ].contiguous()
+                if "g_idx" in key:
+                    if not desc_act:
+                        low_precision_checkpoint_dict.pop(key)
+                else:
+                    low_precision_checkpoint_dict[key] = torch.cat(
+                        data_list, dim=-1
+                    ).contiguous()
+            else:
+                raise AssertionError(f"{quantization_method} is not supported yet.")
+        elif _key_belongs_to(key, mha_layers_split_by_K):
+            if "bias" in key:
+                continue
             data = low_precision_checkpoint_dict[key]
             if ("scales" in key or "qzeros" in key) and data.shape[0] == 1:
                 continue
@@ -1269,7 +1456,7 @@ def shard_low_precision_checkpoint(
                 ]
             else:
                 raise AssertionError(f"{quantization_method} is not supported yet.")
-        elif any(substring in key for substring in mlp_layers_split_by_K):
+        elif _key_belongs_to(key, mlp_layers_split_by_K):
             data = low_precision_checkpoint_dict[key]
             if ("scales" in key or "qzeros" in key) and data.shape[0] == 1:
                 continue
@@ -1422,7 +1609,7 @@ def shard_low_precision_checkpoint(
                 ]
             else:
                 raise AssertionError(f"{quantization_method} is not supported yet.")
-        elif any(substring in key for substring in lm_head_layers):
+        elif _key_belongs_to(key, lm_head_layers):
             # lm_head: [N, K] (not quantized)
             # Same for all quantization methods
             data = low_precision_checkpoint_dict[key]