#!/bin/bash
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
source "$SCRIPT_DIR/common.env"
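# Assumption: common.env provides PROJECT_ROOT, CONFIG_PATH, EXP_NAME, LOG_DIR, CKPT_DIR,
# RUN_LOG, and JSON_METRICS, along with the exit_if_max_steps_reached helper used below.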

# ===== BEGIN CONFIG =====
NUM_NODES=1
STEPS_PER_RUN=40
MAX_STEPS=40
NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up
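# e.g. MAX_STEPS=40 with STEPS_PER_RUN=40 gives NUM_RUNS=1; MAX_STEPS=100 would give 3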
NUM_MINUTES=240
# ===== END CONFIG =====
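
# exit_if_max_steps_reached (from common.env) is assumed to exit this script early
# when a previous run has already completed MAX_STEPS.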
exit_if_max_steps_reached

# Run the experiment (extra CLI args to this script are forwarded as overrides via "$@")
cd "$PROJECT_ROOT"
uv run examples/run_grpo_math.py \
    --config "$CONFIG_PATH" \
    grpo.max_num_steps=$MAX_STEPS \
    logger.log_dir="$LOG_DIR" \
    logger.wandb_enabled=True \
    logger.wandb.project=nemo-rl \
    logger.wandb.name="$EXP_NAME" \
    logger.monitor_gpus=True \
    logger.tensorboard_enabled=True \
    checkpointing.enabled=True \
    checkpointing.checkpoint_dir="$CKPT_DIR" \
    "$@" \
    2>&1 | tee "$RUN_LOG"

# Convert TensorBoard logs to JSON
uv run tests/json_dump_tb_logs.py "$LOG_DIR" --output_path "$JSON_METRICS"
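# JSON_METRICS is assumed to map each metric name to a {step: value} dict,
# which is what the jq query and check_metrics expressions below rely on.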

# Only run the metric checks if the target step was reached
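# (the jq filter below pulls the highest step number logged for train/loss)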
if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' "$JSON_METRICS") -ge $MAX_STEPS ]]; then
  uv run tests/check_metrics.py "$JSON_METRICS" \
    'mean(data["train/token_mult_prob_error"]) < 1.1' \
    "data['train/token_mult_prob_error']['$MAX_STEPS'] < 1.1"
fi

# TODO: enable in a subsequent PR to do a quick accuracy check
## Convert 8k checkpoint
#uv run examples/converters/convert_dcp_to_hf.py \
#    --config=$CKPT_DIR/step_${MAX_STEPS}/config.yaml \
#    --dcp-ckpt-path=$CKPT_DIR/step_${MAX_STEPS}/policy/weights \
#    --hf-ckpt-path=$CKPT_DIR/grpo-deepscaler-8k-${MAX_STEPS}-hf
#
## Run eval
#uv run examples/run_eval.py \
#    generation.model_name=$CKPT_DIR/grpo-deepscaler-8k-${MAX_STEPS}-hf \
#    data.prompt_file=examples/prompts/cot.txt \
#    generation.vllm_cfg.max_model_len=32768 2>&1 | tee ${RUN_LOG}.aime-8k
#
#cat ${RUN_LOG}.aime-8k | grep "score=" | sed 's/.*score=\([^ ]*\).*/{"score": \1}/' > ${RUN_LOG}-8k-metric.json
#
#uv run tests/check_metrics.py ${RUN_LOG}-8k-metric.json \
#    'data["score"] >= 0.25'
#
##uv run examples/run_eval.py \
##    generation.model_name=deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B \
##    data.prompt_file=examples/prompts/cot.txt \
##    generation.vllm_cfg.max_model_len=32768 2>&1 | tee ${RUN_LOG}.aime-baseline
#
##cat ${RUN_LOG}.aime-baseline | grep "score=" | sed 's/.*score=\([^ ]*\).*/{"score": \1}/' > ${RUN_LOG}-baseline-metric.json
#
##uv run tests/check_metrics.py ${RUN_LOG}-baseline-metric.json \
##    'data["score"] == 0.2'