|
#!/bin/bash
#
# Runs examples/run_grpo.py under `coverage` for a short GRPO job
# (Qwen/Qwen3-0.6B, grpo.max_num_steps=2) with explicit sampling settings
# (temperature / top_p / top_k), converts the TensorBoard logs to JSON and
# checks the recorded training metrics against fixed thresholds.
#
# Usage:
#   <this-script>.sh [EXTRA_ARGS...]
#
# EXTRA_ARGS are appended verbatim to the run_grpo.py command line, after the
# settings hard-coded below.
#
# Outputs (all under a directory next to this script, named after the script
# without its .sh suffix; the directory is deleted and recreated on every run):
#   logs/         logger output directory (logger.log_dir)
#   run.log       combined stdout/stderr of the run_grpo.py invocation
#   metrics.json  JSON dump produced by tests/json_dump_tb_logs.py

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd)
PROJECT_ROOT=$(realpath "$SCRIPT_DIR/../..")
# Mark the current repo as safe, since wandb fetches metadata about the repo
git config --global --add safe.directory "$PROJECT_ROOT"

# Strict mode is enabled only from this point on, so a failure of the
# git config call above does not abort the run.
set -euo pipefail

EXP_NAME=$(basename -- "$0" .sh)
EXP_DIR="$SCRIPT_DIR/$EXP_NAME"
LOG_DIR="$EXP_DIR/logs"
JSON_METRICS="$EXP_DIR/metrics.json"
RUN_LOG="$EXP_DIR/run.log"
# ${PYTHONPATH:-} keeps `set -u` from failing when PYTHONPATH is unset.
export PYTHONPATH="${PROJECT_ROOT}:${PYTHONPATH:-}"

# Start from a clean experiment directory. The :? guard aborts instead of
# running `rm -rf` on an empty path should a variable ever expand to nothing.
rm -rf -- "${EXP_DIR:?}" "${LOG_DIR:?}"
mkdir -p -- "$EXP_DIR" "$LOG_DIR"

cd "$PROJECT_ROOT"
# "$@" is quoted so each extra argument reaches run_grpo.py as a single word.
# With pipefail set, a non-zero exit from the run fails the pipeline even
# though tee is the last stage.
uv run coverage run -a --data-file="$PROJECT_ROOT/tests/.coverage" --source="$PROJECT_ROOT/nemo_rl" \
    "$PROJECT_ROOT/examples/run_grpo.py" \
    policy.model_name=Qwen/Qwen3-0.6B \
    grpo.num_prompts_per_step=2 \
    grpo.num_generations_per_prompt=4 \
    policy.train_global_batch_size=4 \
    policy.train_micro_batch_size=1 \
    policy.generation.temperature=0.8 \
    policy.generation.top_p=0.9 \
    policy.generation.top_k=50 \
    cluster.gpus_per_node=2 \
    grpo.max_num_steps=2 \
    logger.tensorboard_enabled=true \
    logger.log_dir="$LOG_DIR" \
    logger.wandb_enabled=false \
    logger.monitor_gpus=true \
    checkpointing.enabled=false \
    "$@" \
    2>&1 | tee "$RUN_LOG"

# Convert the TensorBoard event files under LOG_DIR into a single JSON file.
uv run tests/json_dump_tb_logs.py "$LOG_DIR" --output_path "$JSON_METRICS"

# Each quoted expression is handed to tests/check_metrics.py together with the
# JSON metrics file; the single quotes keep the shell from touching the
# embedded double quotes and brackets.
uv run tests/check_metrics.py "$JSON_METRICS" \
    'max(data["train/token_mult_prob_error"]) < 1.05' \
    'max(data["train/gen_kl_error"]) < 0.03' \
    'min(data["train/probs_ratio_clamped_min"]) > 0.79' \
    'max(data["train/probs_ratio_clamped_min"]) < 1.21' \
    'min(data["train/probs_ratio_clamped_max"]) > 0.79' \
    'max(data["train/probs_ratio_clamped_max"]) < 1.21'
0 commit comments