|
#!/bin/bash

# Resolve the directory containing this script and the repository root.
# Quoting matters: these paths break downstream if they contain spaces.
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
PROJECT_ROOT=$(realpath "$SCRIPT_DIR/../..")
# Mark the current repo as safe, since wandb fetches metadata about the repo
git config --global --add safe.directory "$PROJECT_ROOT"

# Strict mode: exit on error/unset vars, and fail pipelines on any stage.
set -euo pipefail
| 9 | + |
# Per-experiment workspace derived from this script's own filename, so each
# smoke-test script gets an isolated directory next to itself.
EXP_NAME=$(basename "$0" .sh)
EXP_DIR=$SCRIPT_DIR/$EXP_NAME
LOG_DIR=$EXP_DIR/logs
JSON_METRICS=$EXP_DIR/metrics.json
RUN_LOG=$EXP_DIR/run.log
CHECKPOINT_DIR=$EXP_DIR/checkpoints
DATA_DIR=$EXP_DIR/data
export PYTHONPATH=${PROJECT_ROOT}:${PYTHONPATH:-}

# Start from a clean slate. "${EXP_DIR:?}" aborts instead of expanding to an
# empty string, which would otherwise turn this into `rm -rf /` territory.
rm -rf -- "${EXP_DIR:?}" "${LOG_DIR:?}"
mkdir -p -- "$EXP_DIR" "$LOG_DIR" "$CHECKPOINT_DIR" "$DATA_DIR"

cd "$PROJECT_ROOT"
| 23 | + |
# Follow nemo-gym instructions here to get this data:
# https://docs.nvidia.com/nemo/gym/0.1.0/tutorials/nemo-rl-grpo/setup.html#training-nemo-rl-grpo-setup
cd 3rdparty/Gym-workspace/Gym

# Gym reads the huggingface token from env.yaml; create it on first run.
# We need HF_TOKEN to download the data from huggingface.
if [[ ! -f env.yaml ]]; then
  # Fail fast when the token is missing rather than writing a broken config.
  [[ -n "${HF_TOKEN:-}" ]] || { echo "[ERROR] HF_TOKEN is not set"; exit 1; }
  echo "hf_token: $HF_TOKEN" >> env.yaml
fi
| 36 | + |
# Comma-separated list of Gym config files handed to ng_prepare_data below.
config_paths="responses_api_models/vllm_model/configs/vllm_model_for_training.yaml"
config_paths+=",resources_servers/workplace_assistant/configs/workplace_assistant.yaml"
| 39 | + |
# Download and prepare the workplace_assistant train/validation data from
# huggingface into data/workplace_assistant inside the Gym checkout.
uv run ng_prepare_data "+config_paths=[${config_paths}]" \
    +output_dirpath=data/workplace_assistant \
    +mode=train_preparation \
    +should_download=true \
    +data_source=huggingface
# Return to the previous directory (PROJECT_ROOT); note `cd -` echoes the path.
cd -
| 46 | + |
# This trimming of the workplace assistant dataset is necessary b/c with all the tools the first prompt is >4000 tokens
# which will cause vllm to return nothing on the first prompt and crash RL. Since we want to keep this test short to
# smoke test, we trim all but the first tool.
TRAIN_PATH=$DATA_DIR/workplace_assistant_train.jsonl
VALIDATION_PATH=$DATA_DIR/workplace_assistant_validation.jsonl
# `tools |= .[0:1]` keeps only the first tool of each record; redirect targets
# are quoted so spaced checkout paths do not word-split.
jq -c '.responses_create_params.tools |= (.[0:1])' \
  3rdparty/Gym-workspace/Gym/data/workplace_assistant/train.jsonl > "$TRAIN_PATH"
jq -c '.responses_create_params.tools |= (.[0:1])' \
  3rdparty/Gym-workspace/Gym/data/workplace_assistant/validation.jsonl > "$VALIDATION_PATH"
| 54 | + |
# Run a short async-GRPO smoke test on a tiny Qwen3 model under coverage.
# Any extra CLI overrides given to this script are forwarded via "$@" —
# quoted, so overrides containing spaces survive word-splitting intact.
uv run coverage run -a --data-file="$PROJECT_ROOT/tests/.coverage" --source="$PROJECT_ROOT/nemo_rl" \
    "$PROJECT_ROOT/examples/nemo_gym/run_grpo_nemo_gym.py" \
    --config "$PROJECT_ROOT/examples/nemo_gym/grpo_qwen3_30ba3b_instruct.yaml" \
    policy.model_name=Qwen/Qwen3-0.6B \
    policy.dtensor_cfg.enabled=false \
    policy.megatron_cfg.enabled=true \
    policy.megatron_cfg.tensor_model_parallel_size=1 \
    policy.megatron_cfg.pipeline_model_parallel_size=1 \
    policy.megatron_cfg.expert_model_parallel_size=1 \
    policy.megatron_cfg.context_parallel_size=1 \
    policy.megatron_cfg.sequence_parallel=false \
    policy.generation.vllm_cfg.tensor_parallel_size=1 \
    policy.generation.vllm_cfg.async_engine=true \
    policy.max_total_sequence_length=512 \
    policy.generation.colocated.enabled=false \
    policy.generation.colocated.resources.num_nodes=1 \
    policy.generation.colocated.resources.gpus_per_node=1 \
    grpo.num_prompts_per_step=4 \
    grpo.num_generations_per_prompt=2 \
    grpo.max_num_steps=10 \
    grpo.async_grpo.enabled=true \
    grpo.async_grpo.max_trajectory_age_steps=1 \
    grpo.async_grpo.in_flight_weight_updates=true \
    policy.train_global_batch_size=4 \
    policy.train_micro_batch_size=1 \
    cluster.gpus_per_node=2 \
    loss_fn.use_importance_sampling_correction=true \
    logger.tensorboard_enabled=true \
    logger.log_dir="$LOG_DIR" \
    logger.wandb_enabled=false \
    logger.monitor_gpus=true \
    checkpointing.enabled=false \
    checkpointing.checkpoint_dir="$CHECKPOINT_DIR" \
    data.train.data_path="$TRAIN_PATH" \
    data.validation.data_path="$VALIDATION_PATH" \
    "$@" \
    2>&1 | tee "$RUN_LOG"
| 92 | + |
# Dump the tensorboard event files to JSON so metrics can be asserted on.
uv run tests/json_dump_tb_logs.py "$LOG_DIR" --output_path "$JSON_METRICS"

# Observed to be between 0.8-1.3
uv run tests/check_metrics.py "$JSON_METRICS" \
    'median(data["train/gen_kl_error"]) < 1.3'
0 commit comments