diff --git a/apps/grpo/qwen3_32b.yaml b/apps/grpo/qwen3_32b.yaml
index 593a2e1fb..e7a0cf509 100644
--- a/apps/grpo/qwen3_32b.yaml
+++ b/apps/grpo/qwen3_32b.yaml
@@ -3,10 +3,10 @@
 # NOTE - This has not been tested for correctness yet! All testing so far has been only for infrastructure stability

 # Global configuration
-group_size: 2
-local_batch_size: 8 # per-device batch size
-max_req_tokens: 512
-max_res_tokens: 512
+group_size: 16
+local_batch_size: 32 # per-device batch size
+max_req_tokens: 1024
+max_res_tokens: 1024
 model: "Qwen/Qwen3-32B"
 off_by_n: 1 # Off by one by default
@@ -14,7 +14,7 @@ provisioner:
   launcher: slurm

 # Main loop configuration
-rollout_threads: 1 # Recommended to set equal to policy.num_replicas
+rollout_threads: 32 # setting this to 4x the number of policy replicas seems to work well

 # Observability configuration
 metric_logging:
@@ -69,8 +69,8 @@ trainer:
     enable: false
   parallelism:
     data_parallel_replicate_degree: 1
-    data_parallel_shard_degree: -1
-    tensor_parallel_degree: 1
+    data_parallel_shard_degree: 1
+    tensor_parallel_degree: 8
     pipeline_parallel_degree: 1
     context_parallel_degree: 1
     expert_parallel_degree: 1
@@ -90,7 +90,7 @@ replay_buffer:
   batch_size: ${local_batch_size}
   max_policy_age: ${off_by_n}
   # dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree
-  dp_size: 8
+  dp_size: 1

 # Reference model configuration
 ref_model:
@@ -119,7 +119,7 @@ ref_model:
 services:
   policy:
     procs: ${policy.engine_args.tensor_parallel_size}
-    num_replicas: 1
+    num_replicas: 4
     hosts: 1
     with_gpus: true
     mesh_name: policy
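
As a sanity check on the new topology, a rough GPU-accounting sketch follows. It assumes `policy.engine_args.tensor_parallel_size` is 8; that value lives outside the hunks above, so treat it as an assumption rather than something this diff confirms.

# Hypothetical GPU budget implied by this change
# (assumes policy.engine_args.tensor_parallel_size == 8 -- not visible in this diff)
#
# trainer:  data_parallel_shard_degree (1) x tensor_parallel_degree (8) = 8 GPUs
# policy:   num_replicas (4) x procs per replica (tensor_parallel_size, 8) = 32 GPUs
#
# replay_buffer.dp_size must equal the trainer DP degree (per the commented-out
# interpolation in the replay_buffer hunk); with data_parallel_shard_degree
# pinned to 1 instead of -1 (auto), dp_size drops from 8 to 1 to match.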