
Commit b6cd6b7

Use Mcore GPTModel (verl-project#706)
Use the official Megatron-Core GPTModel in the Megatron worker, supporting both actor and critic workers.
Parent: 4fec38c

File tree: 15 files changed, +1494 −296 lines
scripts/model_merger.py

Lines changed: 135 additions & 139 deletions
@@ -238,156 +238,129 @@ def process_one_shard(shard_dir):
     if args.test:
         ref_state_dict = load_file(os.path.join(args.test_hf_dir, 'model.safetensors'))

-    def handle_qkv_proj(key, config, tensor, state_dict):
-        nonlocal tp_size
-
-        hidden_size_per_head = config.hidden_size // config.num_attention_heads
-
-        if config.num_key_value_heads >= tp_size:
-            q_size_tp = config.hidden_size // tp_size
-            kv_size_tp = hidden_size_per_head * config.num_key_value_heads // tp_size
-            total_size = q_size_tp + 2 * kv_size_tp
-            q_part = tensor[:q_size_tp]
-            k_part = tensor[q_size_tp:q_size_tp + kv_size_tp]
-            v_part = tensor[q_size_tp + kv_size_tp:total_size]
-        else:
-            q_size_tp = config.hidden_size // tp_size
-            kv_size_tp = hidden_size_per_head
-            total_size = q_size_tp + 2 * kv_size_tp
-            q_part = tensor[:q_size_tp]
-            k_part = tensor[q_size_tp:q_size_tp + kv_size_tp]
-            v_part = tensor[q_size_tp + kv_size_tp:total_size]
-
-        preffix = '.'.join(key.split('.')[:4])
-        suffix = '.'.join(key.split('.')[5:])
-        if state_dict.get(f'{preffix}.q_proj.{suffix}') is None:
-            state_dict[f'{preffix}.q_proj.{suffix}'] = q_part
-        else:
-            state_dict[f'{preffix}.q_proj.{suffix}'] = torch.concat([state_dict[f'{preffix}.q_proj.{suffix}'], q_part], dim=0)
-        if state_dict.get(f'{preffix}.k_proj.{suffix}') is None:
-            state_dict[f'{preffix}.k_proj.{suffix}'] = k_part
-        else:
-            state_dict[f'{preffix}.k_proj.{suffix}'] = torch.concat([state_dict[f'{preffix}.k_proj.{suffix}'], k_part], dim=0)
-        if state_dict.get(f'{preffix}.v_proj.{suffix}') is None:
-            state_dict[f'{preffix}.v_proj.{suffix}'] = v_part
+    def merge_across_tp(key, tp_data):
+        if "linear_fc1.weight" in key:
+            # if the tensor is gate and proj
+            gate_lst = []
+            up_lst = []
+            for infer_param in tp_data:
+                gate, up = infer_param.chunk(2)
+                gate_lst.append(gate)
+                up_lst.append(up)
+            gate = torch.cat(gate_lst, dim=0)
+            up = torch.cat(up_lst, dim=0)
+            tp_data = [gate, up]
+        elif "self_attention.linear_qkv." in key and 'layer_norm' not in key:
+            # if the tensor is qkv, for each param on tp, split into q, k, v
+            # concat q, k, v separately.
+            q_lst = []
+            k_lst = []
+            v_lst = []
+            assert config.num_attention_heads % config.num_key_value_heads == 0
+            num_q_per_kv = config.num_attention_heads // config.num_key_value_heads
+            assert tp_data[0].shape[0] % (num_q_per_kv + 2) == 0
+            kv_size_per_tp = tp_data[0].shape[0] // (num_q_per_kv + 2)
+            split_size = [kv_size_per_tp * num_q_per_kv, kv_size_per_tp, kv_size_per_tp]
+            for infer_param in tp_data:
+                num_query_groups_per_partition = config.num_key_value_heads // tp_size
+                for chunk in infer_param.chunk(num_query_groups_per_partition):
+                    split_size = [
+                        kv_size_per_tp * num_q_per_kv // num_query_groups_per_partition,
+                        kv_size_per_tp // num_query_groups_per_partition,
+                        kv_size_per_tp // num_query_groups_per_partition
+                    ]
+                    q, k, v = chunk.split(split_size)
+                    q_lst.append(q)
+                    k_lst.append(k)
+                    v_lst.append(v)
+            q = torch.cat(q_lst, dim=0)
+            k = torch.cat(k_lst, dim=0)
+            v = torch.cat(v_lst, dim=0)
+
+            tp_data = [q,k,v]
+
+        elif "layer_norm" in key or "layernorm" in key or "output_layer" in key and args.is_value_model:
+            tp_data = tp_data[0]
         else:
-            state_dict[f'{preffix}.v_proj.{suffix}'] = torch.concat([state_dict[f'{preffix}.v_proj.{suffix}'], v_part], dim=0)
+            dim = 0
+            if "linear_fc2.weight" in key or "self_attention.linear_proj" in key:
+                dim = 1
+            tp_data = torch.cat(tp_data, dim=dim)

-        return state_dict

-    def handle_gate_up_proj(key, config, tensor, state_dict):
-        nonlocal tp_size
-
-        intermediate_size_tp = config.intermediate_size // tp_size
-        gate_weight_tp = tensor[:intermediate_size_tp]
-        up_weight_tp = tensor[intermediate_size_tp:]
-        preffix = '.'.join(key.split('.')[:4])
-        suffix = '.'.join(key.split('.')[5:])
-        if state_dict.get(f'{preffix}.gate_proj.{suffix}') is None:
-            state_dict[f'{preffix}.gate_proj.{suffix}'] = gate_weight_tp
-        else:
-            state_dict[f'{preffix}.gate_proj.{suffix}'] = torch.concat([state_dict[f'{preffix}.gate_proj.{suffix}'], gate_weight_tp], dim=0)
-        if state_dict.get(f'{preffix}.up_proj.{suffix}') is None:
-            state_dict[f'{preffix}.up_proj.{suffix}'] = up_weight_tp
-        else:
-            state_dict[f'{preffix}.up_proj.{suffix}'] = torch.concat([state_dict[f'{preffix}.up_proj.{suffix}'], up_weight_tp], dim=0)
-
-        return state_dict
-
-    def merge_between_tp_rank(key, model_state_dict):
-        nonlocal state_dict
-
-        try:
-            tensor = model_state_dict.pop(key)
-        except:
-            raise RuntimeError(f"error pop: {key}")
-        # Embedding layer
-        if "model.embed_tokens.weight" in key:
-            if state_dict[key] is None:
-                state_dict[key] = tensor
-            else:
-                state_dict[key] = torch.concat([state_dict[key], tensor], dim=0)
-            return state_dict
-        # Tranformer Layers
-        if "input_layernorm.weight" in key:
-            if state_dict[key] is None:
-                state_dict[key] = tensor
-            return state_dict
-        if re.search(r"self_attn\.qkv_proj", key):
-            state_dict = handle_qkv_proj(key, config, tensor, state_dict)
-            return state_dict
-        if "self_attn.o_proj.weight" in key:
-            if state_dict[key] is None:
-                state_dict[key] = tensor
-            else:
-                state_dict[key] = torch.concat([state_dict[key], tensor], dim=1)
-            return state_dict
-        if "post_attention_layernorm.weight" in key:
-            if state_dict[key] is None:
-                state_dict[key] = tensor
-            return state_dict
-        if re.search(r"mlp\.gate_up_proj\.weight", key):
-            state_dict = handle_gate_up_proj(key, config, tensor, state_dict)
-            return state_dict
-        if "mlp.down_proj.weight" in key:
-            if state_dict[key] is None:
-                state_dict[key] = tensor
-            else:
-                state_dict[key] = torch.concat([state_dict[key], tensor], dim=1)
-            return state_dict
-        # Final LayerNorm
-        if "model.norm.weight" in key:
-            if state_dict[key] is None:
-                state_dict[key] = tensor
-            return state_dict
-        if not args.tie_word_embedding:
-            if args.is_value_model:
-                if "lm_head.weight" in key:
-                    if state_dict[key] is None:
-                        state_dict[key] = tensor
-                if "reward_head.weight" in key:
-                    if state_dict[key] is None:
-                        state_dict[key] = tensor
-            else:
-                if "lm_head.weight" in key:
-                    if state_dict[key] is None:
-                        state_dict[key] = tensor
-                    else:
-                        state_dict[key] = torch.concat([state_dict[key], tensor], dim=0)
-            return state_dict
-        return state_dict
+        return tp_data

-    for pp_rank in range(pp_size):
-        print(f'pp_rank: {pp_rank}')
-        for vpp_rank, state_dict_single_layer in enumerate(model_state_dict_lst[pp_rank][0]):
-            state_dict_single_layer_iter = state_dict_single_layer.copy()
-            keys = state_dict_single_layer_iter.keys()
+    vpp_size = len(model_state_dict_lst[0][0])
+    layers_cum = 0
+    for vpp_rank in range(vpp_size):
+        for pp_rank in range(pp_size):
+            layers_handled = 0
+            keys = model_state_dict_lst[pp_rank][0][vpp_rank].keys()
             for key in keys:
                 if "extra_state" in key:
                     continue
-                if args.tie_word_embedding and ("lm_head" in key or "reward_head" in key):
+                if args.tie_word_embedding and ("output_layer" in key):
                     print(f'skip lm_head and reward_head loading because of tie_word_embeddings')
                     continue
-                if re.search(r"self_attn\.qkv_proj", key) is None and re.search(r"gate_up_proj", key) is None:
-                    state_dict[key] = None
-                for tp_rank in range(tp_size):
-                    model_state_dict = model_state_dict_lst[pp_rank][tp_rank][vpp_rank]
-                    state_dict = merge_between_tp_rank(key, model_state_dict)
-
+                new_key = key
+                if "decoder.layers." in key:
+                    local_layer_no = int(key.split('.')[2])
+                    layers_handled = max(local_layer_no, layers_handled)
+                    global_layer_no = local_layer_no + layers_cum
+                    new_key_list = key.split('.')
+                    new_key_list[2] = str(global_layer_no)
+                    new_key = '.'.join(new_key_list)
+
+                tp_data = [model_state_dict_lst[pp_rank][tp_rank][vpp_rank][key] for tp_rank in range(tp_size)]
+                merged = merge_across_tp(new_key, tp_data)
+                if not isinstance(merged, list):
+                    state_dict[new_key] = merged
+                elif len(merged) == 3:
+                    # split qkv
+                    for n, d in zip(['q', 'k', 'v'], merged):
+                        state_dict[new_key.replace("linear_qkv", f"linear_{n}")] = d
+                elif len(merged) == 2:
+                    # split gate up
+                    state_dict[new_key.replace("linear_fc1", "gate_proj")] = merged[0]
+                    state_dict[new_key.replace("linear_fc1", "up_proj")] = merged[1]
+            layers_cum += layers_handled + 1  # zero based
+
     del model_state_dict_lst
+
+    params_mapping = [
+        # (megatron core gpt model name, vllm model name)
+        ("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"),
+        ("self_attention.linear_qkv.layer_norm_bias", "input_layernorm.bias"),
+        ("embedding.word_embeddings", "model.embed_tokens"),
+        ("self_attention.linear_qkv", "self_attn.qkv_proj"),
+        ("self_attention.linear_proj", "self_attn.o_proj"),
+        ("pre_mlp_layernorm", "post_attention_layernorm"),
+        ("mlp.linear_fc1.layer_norm_weight", "post_attention_layernorm.weight"),
+        ("mlp.linear_fc1.layer_norm_bias", "post_attention_layernorm.bias"),
+        ("mlp.linear_fc1", "mlp.gate_up_proj"),
+        ("mlp.linear_fc2", "mlp.down_proj"),
+        ("decoder.final_layernorm", "model.norm"),
+        ("output_layer", "lm_head"),
+        ("self_attention.linear_q", "self_attn.q_proj"),
+        ("self_attention.linear_k", "self_attn.k_proj"),
+        ("self_attention.linear_v", "self_attn.v_proj"),
+    ]
+
     if args.test:
-        for key, value in state_dict.items():
-            print(key)
-            if key not in ref_state_dict:
-                raise RuntimeError(f'key: {key} not exist in ref_state_dict {value}')
-            if value.shape != ref_state_dict[key].shape:
-                raise RuntimeError(f'key: {key} shape mismatch {value.shape}, {ref_state_dict[key].shape}')
-            assert value.dtype == ref_state_dict[key].dtype, f'{key} state_dict[key].dtype: {value.dtype} != ref_state_dict[key].dtype: {ref_state_dict[key].dtype}'
-            torch.testing.assert_close(value, ref_state_dict[key], atol=1e-4, rtol=1e-4)
-        for key in ref_state_dict:
-            if key not in state_dict:
-                raise RuntimeError(f'key: {key} not exist in state_dict {ref_state_dict[key]}')
-
+
+        for original_name, loaded_weight in state_dict.items():
+            name = _replace_name(original_name, params_mapping)
+            if not name or name.endswith(".bias") and name not in ref_state_dict:
+                continue
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if args.tie_word_embedding and "lm_head.weight" in name:
+                continue
+            if name not in ref_state_dict:
+                raise RuntimeError(f'key: {name} not exist in state_dict')
+            param = ref_state_dict[name]
+            assert loaded_weight.dtype == param.dtype
+            torch.testing.assert_close(loaded_weight, param, atol=1e-4, rtol=1e-4)

     print('Writing to local disk')
     if args.target_dir is None:

@@ -415,6 +388,29 @@ def merge_between_tp_rank(key, model_state_dict):
     if args.hf_upload_path:
         upload_model_to_huggingface(hf_path)

+
+def _replace_name(megatron_name, name_mapping):
+    for m_name, v_name in name_mapping:
+        if m_name not in megatron_name:
+            continue
+        if "layers" in megatron_name:  # deal with decoder layers
+            megatron_name = megatron_name.replace("decoder", "model")
+            megatron_name_list = megatron_name.split(".")
+            if "layer_norm_weight" in megatron_name_list or "layer_norm_bias" in megatron_name_list:
+                param_name_list = megatron_name_list[:3]
+                param_name_list.append(v_name)
+                param_name = ".".join(param_name_list)
+            else:
+                param_name_list = megatron_name_list[:3]
+                weight_or_bias = megatron_name_list[-1]
+                param_name_list.append(v_name)
+                param_name_list.append(weight_or_bias)
+                param_name = ".".join(param_name_list)
+            return param_name
+        else:
+            param_name = megatron_name.replace(m_name, v_name)
+            return param_name
+
 if __name__ == '__main__':
     if args.backend == "fsdp":
         convert_fsdp_checkpoints_to_hfmodels()
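
For reference, a standalone sketch (toy dimensions, not part of the commit) of what the qkv branch of the new merge_across_tp does: each tensor-parallel shard of linear_qkv stores its query groups as interleaved [q ... q, k, v] row blocks, so the merger splits every group and then concatenates the q, k and v rows separately.

import torch

# assumed toy sizes: 4 query heads, 2 kv heads (GQA), head_dim 2, TP size 2
num_attention_heads, num_key_value_heads, head_dim, tp_size = 4, 2, 2, 2
num_q_per_kv = num_attention_heads // num_key_value_heads   # query heads per kv head
groups_per_rank = num_key_value_heads // tp_size             # query groups held by one TP rank

# one fake linear_qkv weight shard per TP rank; each group contributes
# (num_q_per_kv + 2) * head_dim rows laid out as [q ... q, k, v]
rows_per_rank = groups_per_rank * (num_q_per_kv + 2) * head_dim
tp_data = [
    torch.arange(rows_per_rank * 8, dtype=torch.float32).reshape(rows_per_rank, 8) + 1000 * rank
    for rank in range(tp_size)
]

q_lst, k_lst, v_lst = [], [], []
for shard in tp_data:
    for group in shard.chunk(groups_per_rank):
        q, k, v = group.split([num_q_per_kv * head_dim, head_dim, head_dim])
        q_lst.append(q)
        k_lst.append(k)
        v_lst.append(v)

q, k, v = torch.cat(q_lst), torch.cat(k_lst), torch.cat(v_lst)
assert q.shape[0] == num_attention_heads * head_dim    # full q_proj rows
assert k.shape[0] == num_key_value_heads * head_dim     # full k_proj rows
assert v.shape[0] == num_key_value_heads * head_dim     # full v_proj rows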

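The test path no longer compares Megatron names directly; it first runs every merged key through _replace_name with params_mapping. A minimal sketch of that translation (the helper below is a lightly condensed restatement of the function added in this commit; the example keys and the truncated mapping are illustrative):

# subset of params_mapping, enough for the two example keys below
params_mapping = [
    ("self_attention.linear_qkv.layer_norm_weight", "input_layernorm.weight"),
    ("self_attention.linear_proj", "self_attn.o_proj"),
    ("mlp.linear_fc2", "mlp.down_proj"),
    ("decoder.final_layernorm", "model.norm"),
]

def _replace_name(megatron_name, name_mapping):
    for m_name, v_name in name_mapping:
        if m_name not in megatron_name:
            continue
        if "layers" in megatron_name:  # decoder layers keep their layer index
            megatron_name = megatron_name.replace("decoder", "model")
            megatron_name_list = megatron_name.split(".")
            if "layer_norm_weight" in megatron_name_list or "layer_norm_bias" in megatron_name_list:
                param_name_list = megatron_name_list[:3]
                param_name_list.append(v_name)
                return ".".join(param_name_list)
            else:
                param_name_list = megatron_name_list[:3]
                weight_or_bias = megatron_name_list[-1]
                param_name_list.append(v_name)
                param_name_list.append(weight_or_bias)
                return ".".join(param_name_list)
        else:
            return megatron_name.replace(m_name, v_name)

print(_replace_name("decoder.layers.0.self_attention.linear_proj.weight", params_mapping))
# model.layers.0.self_attn.o_proj.weight
print(_replace_name("decoder.final_layernorm.weight", params_mapping))
# model.norm.weight
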
verl/models/llama/megatron/layers/parallel_linear.py

Lines changed: 31 additions & 0 deletions
@@ -72,3 +72,34 @@ def __init__(self,
                          gather_output=gather_output,
                          skip_bias_add=skip_bias_add,
                          **kwargs)
+
+
+import torch
+
+
+class LinearForLastLayer(torch.nn.Linear):
+
+    def __init__(
+        self,
+        input_size,
+        output_size,
+        *,
+        config,
+        bias=True,
+    ):
+        super().__init__(in_features=input_size, out_features=output_size, bias=bias)
+        self.sequence_parallel = config.sequence_parallel
+        if self.sequence_parallel:
+            setattr(self.weight, 'sequence_parallel', True)
+
+    def forward(
+        self,
+        input_,
+        weight=None,
+        runtime_gather_output=None,
+    ):
+        logits = super().forward(input_)
+        logits = logits.float()
+        if self.sequence_parallel:
+            logits = tensor_parallel.gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)
+        return logits, None
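
A standalone sketch of how the new LinearForLastLayer behaves when sequence parallelism is off (the Megatron tensor_parallel gather branch is omitted and the class is re-declared locally so the snippet runs without Megatron installed): it acts like a plain nn.Linear but returns an (output, bias) tuple with fp32 logits, matching the tuple that Megatron's column-parallel output layer returns.

import torch

class LinearForLastLayerSketch(torch.nn.Linear):
    # simplified stand-in for verl's LinearForLastLayer with sequence_parallel disabled
    def forward(self, input_):
        logits = super().forward(input_).float()   # cast logits to fp32
        return logits, None                        # (output, bias) tuple

head = LinearForLastLayerSketch(in_features=16, out_features=32000, bias=False)
hidden = torch.randn(4, 16)
logits, bias = head(hidden)
assert logits.dtype == torch.float32 and logits.shape == (4, 32000) and bias is None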

verl/models/mcore/__init__.py

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+# Copyright 2025 Bytedance Ltd. and/or its affiliates
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .gpt_model import gptmodel_forward
