Commit 441f745
feat: Add GPT-OSS support via mcore (#1452)
Signed-off-by: ashors1 <[email protected]>
1 parent 7dd9a01 commit 441f745

File tree

4 files changed: +81 −0 lines changed
Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
+defaults: ../../grpo_math_1B.yaml
+grpo:
+  num_prompts_per_step: 64
+  num_generations_per_prompt: 32
+loss_fn:
+  use_importance_sampling_correction: true
+policy:
+  model_name: openai/gpt-oss-20b
+  train_micro_batch_size: 1
+  max_total_sequence_length: 4096
+  megatron_cfg:
+    enabled: true
+    expert_model_parallel_size: 8
+    tensor_model_parallel_size: 4
+    sequence_parallel: true
+    moe_permute_fusion: true
+  dtensor_cfg:
+    enabled: false
+  sequence_packing:
+    enabled: false
+  generation:
+    vllm_cfg:
+      tensor_parallel_size: 2
+cluster:
+  num_nodes: 8
+  gpus_per_node: 8
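A quick way to sanity-check the parallelism budget in this recipe (an illustrative sketch, not part of the commit; it assumes pipeline and context parallel sizes of 1 and mirrors Megatron-Core's usual layout, where expert parallelism is carved out of the data-parallel ranks):

# Illustrative check of the recipe's parallel layout (assumptions noted above).
world_size = 8 * 8     # cluster.num_nodes * cluster.gpus_per_node
tp, ep = 4, 8          # tensor_model_parallel_size, expert_model_parallel_size
dp = world_size // tp  # data-parallel size for the dense/attention params
assert world_size % tp == 0 and dp % ep == 0, "EP must divide the DP size"
print(f"dp={dp}, expert-data-parallel={dp // ep}")  # dp=16, expert-data-parallel=2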

nemo_rl/models/generation/vllm/vllm_worker.py

Lines changed: 12 additions & 0 deletions
@@ -20,6 +20,7 @@

 import ray
 import torch
+from transformers import AutoConfig

 from nemo_rl.distributed.batched_data_dict import BatchedDataDict
 from nemo_rl.distributed.worker_group_utils import get_nsight_config_if_pattern_matches
@@ -305,6 +306,17 @@ def _patch_vllm_init_workers_ray():
             self.cfg["vllm_cfg"].get("hf_overrides", {}) or {}
         )

+        # Override HF config for gpt-oss models to ensure compatibility with megatron
+        # The megatron --> hf export is done in bf16, so we disable quantization
+        hf_config = AutoConfig.from_pretrained(self.model_name, trust_remote_code=True)
+        if "GptOssForCausalLM" in getattr(hf_config, "architectures", []):
+            if "quantization_config" in hf_config:
+                assert load_format == "dummy", (
+                    "Loading quantized GPT-OSS models is currently only supported with load_format='dummy'."
+                )
+                # disable quantization
+                vllm_kwargs["hf_overrides"]["quantization_config"] = {}
+
         llm_kwargs = dict(
             model=self.model_name,
             served_model_name=self.model_name,
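Outside the worker, the effect of this override can be reproduced by handing the same hf_overrides dict straight to vLLM (a hedged sketch, not part of the commit; it assumes a vLLM version that accepts hf_overrides, and per the assert above, the quantized checkpoint can only be paired with load_format="dummy" since the real bf16 weights arrive later via the Megatron --> HF export):

# Sketch: construct vLLM with the same override the worker applies above.
from vllm import LLM

llm = LLM(
    model="openai/gpt-oss-20b",
    load_format="dummy",                       # skip loading the quantized checkpoint weights
    hf_overrides={"quantization_config": {}},  # treat the model config as unquantized (bf16)
)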
Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+#!/bin/bash
+SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
+source $SCRIPT_DIR/common.env
+
+# ===== BEGIN CONFIG =====
+NUM_NODES=8
+STEPS_PER_RUN=60
+MAX_STEPS=60
+NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN ))  # Round up
+NUM_MINUTES=240
+# ===== END CONFIG =====
+
+exit_if_max_steps_reached
+
+# Run the experiment
+cd $PROJECT_ROOT
+uv run examples/run_grpo_math.py \
+    --config $CONFIG_PATH \
+    grpo.max_num_steps=$MAX_STEPS \
+    logger.log_dir=$LOG_DIR \
+    logger.wandb_enabled=True \
+    logger.wandb.project=nemo-rl \
+    logger.wandb.name=$EXP_NAME \
+    logger.monitor_gpus=True \
+    logger.tensorboard_enabled=True \
+    checkpointing.enabled=True \
+    checkpointing.checkpoint_dir=$CKPT_DIR \
+    "$@" \
+    2>&1 | tee $RUN_LOG
+
+# Convert tensorboard logs to json
+uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS
+
+# Only run metrics if the target step is reached
+if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then
+    uv run tests/check_metrics.py $JSON_METRICS \
+        'mean(data["train/gen_kl_error"]) < 0.002' \
+        'data["train/reward"]["60"] > 0.60' \
+        'mean(data["timing/train/total_step_time"], -6, -1) < 210'
+fi
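The jq guard above just finds the highest step number logged under "train/loss" and compares it to MAX_STEPS; a Python equivalent (illustrative only, with a stand-in path for $JSON_METRICS):

# Python rendering of the jq expression in the guard above.
import json

with open("json_metrics.json") as f:  # stand-in for $JSON_METRICS
    metrics = json.load(f)

# keys of metrics["train/loss"] are step numbers stored as strings
max_logged_step = max(int(step) for step in metrics["train/loss"])
if max_logged_step >= 60:  # MAX_STEPS
    print("target step reached; run check_metrics.py")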

tests/test_suites/release.txt

Lines changed: 3 additions & 0 deletions
@@ -23,6 +23,9 @@ tests/test_suites/llm/dapo-qwen2.5-7b.sh
 # Deepseek-V3 on DAPO dataset
 tests/test_suites/llm/grpo-dapomath17k-dsv3-megatron.sh

+# GPT-OSS
+tests/test_suites/llm/grpo-gptoss-20b-8n8g-megatron.sh
+
 #######
 # SFT #
 #######
