Skip to content

Commit 3e03356

Browse files
committed
add functional test
Signed-off-by: Yuki Huang <yukih@nvidia.com>
1 parent 289e1cc commit 3e03356

File tree

1 file changed

+49
-0
lines changed

1 file changed

+49
-0
lines changed

tests/functional/grpo_topp_topk.sh

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#!/bin/bash
#
# Functional test: runs examples/run_grpo.py for two GRPO steps with
# temperature / top-p / top-k sampling enabled, dumps the TensorBoard logs to
# JSON, and checks the resulting training metrics against fixed bounds.
#
# Usage: tests/functional/grpo_topp_topk.sh [extra run_grpo.py overrides...]
#   Any arguments are forwarded verbatim to run_grpo.py after the built-in
#   overrides.

SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &>/dev/null && pwd)
PROJECT_ROOT=$(realpath "$SCRIPT_DIR/../..")
# Mark the current repo as safe, since wandb fetches metadata about the repo.
# This runs before strict mode is enabled, so a failure here does not abort
# the script.
git config --global --add safe.directory "$PROJECT_ROOT"

set -euo pipefail

# All artifacts live in a directory named after this script, next to it.
EXP_NAME=$(basename -- "$0" .sh)
EXP_DIR="$SCRIPT_DIR/$EXP_NAME"
LOG_DIR="$EXP_DIR/logs"
JSON_METRICS="$EXP_DIR/metrics.json"
RUN_LOG="$EXP_DIR/run.log"
export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}"

# Start from a clean experiment directory. The :? guards abort instead of
# handing an empty path to rm -rf.
rm -rf -- "${EXP_DIR:?}" "${LOG_DIR:?}"
mkdir -p -- "$EXP_DIR" "$LOG_DIR"

cd "$PROJECT_ROOT"

# Run the training under coverage, appending (-a) to the shared data file.
# "$@" is quoted so caller-supplied overrides are passed through unchanged
# (no re-splitting or globbing). Output is mirrored to $RUN_LOG via tee; with
# pipefail set, a failure of the training command still fails the pipeline.
uv run coverage run -a \
  --data-file="$PROJECT_ROOT/tests/.coverage" \
  --source="$PROJECT_ROOT/nemo_rl" \
  "$PROJECT_ROOT/examples/run_grpo.py" \
  policy.model_name=Qwen/Qwen3-0.6B \
  grpo.num_prompts_per_step=2 \
  grpo.num_generations_per_prompt=4 \
  policy.train_global_batch_size=4 \
  policy.train_micro_batch_size=1 \
  policy.generation.temperature=0.8 \
  policy.generation.top_p=0.9 \
  policy.generation.top_k=50 \
  cluster.gpus_per_node=2 \
  grpo.max_num_steps=2 \
  logger.tensorboard_enabled=true \
  logger.log_dir="$LOG_DIR" \
  logger.wandb_enabled=false \
  logger.monitor_gpus=true \
  checkpointing.enabled=false \
  "$@" \
  2>&1 | tee "$RUN_LOG"

# Convert the TensorBoard event logs into a single JSON metrics file.
uv run tests/json_dump_tb_logs.py "$LOG_DIR" --output_path "$JSON_METRICS"

# Each quoted argument is a check expression over `data`, handed as-is to
# check_metrics.py (presumably evaluated against the dumped metrics -- see
# tests/check_metrics.py for the exact semantics).
uv run tests/check_metrics.py "$JSON_METRICS" \
  'max(data["train/token_mult_prob_error"]) < 1.05' \
  'max(data["train/gen_kl_error"]) < 0.03' \
  'min(data["train/probs_ratio_clamped_min"]) > 0.79' \
  'max(data["train/probs_ratio_clamped_min"]) < 1.21' \
  'min(data["train/probs_ratio_clamped_max"]) > 0.79' \
  'max(data["train/probs_ratio_clamped_max"]) < 1.21'

0 commit comments

Comments
 (0)