apps/grpo/qwen3_32b.yaml (18 changes: 9 additions & 9 deletions)

@@ -3,18 +3,18 @@
 # NOTE - This has not been tested for correctness yet! All testing so far has been only for infrastructure stability

 # Global configuration
-group_size: 2
-batch_size: 8
-max_req_tokens: 512
-max_res_tokens: 512
+group_size: 16
+batch_size: 32
+max_req_tokens: 1024
+max_res_tokens: 1024
 model: "Qwen/Qwen3-32B"
 off_by_n: 1 # Off by one by default

 provisioner:
   launcher: slurm

 # Main loop configuration
-rollout_threads: 1 # Recommended to set equal to policy.num_replicas
+rollout_threads: 8
Contributor comment on rollout_threads: 8:

I feel we could double this. I set rollout_threads to 4 with 1 policy replica, and the waiting-for-buffer time went from 300s to 60s.
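For context on the numbers: the PR sets 8 rollout threads against 4 policy replicas (2 per replica), while the commenter's test ran 4 threads per replica. A sketch of the doubled setting, with the ratio spelled out in a comment since the config's ${...} interpolation cannot multiply (this is the commenter's suggestion, not a tested setting):

  # hypothetical follow-up, not part of this PR: 4 rollout threads per policy replica
  rollout_threads: 16  # = 4 * services.policy.num_replicas (num_replicas: 4 in this file)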


 # Observability configuration
 metric_logging:
@@ -69,8 +69,8 @@ trainer:
     enable: false
   parallelism:
     data_parallel_replicate_degree: 1
-    data_parallel_shard_degree: -1
-    tensor_parallel_degree: 1
+    data_parallel_shard_degree: 1
+    tensor_parallel_degree: 8
     pipeline_parallel_degree: 1
     context_parallel_degree: 1
     expert_parallel_degree: 1
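Both files flip the trainer from data-parallel sharding to tensor parallelism, which fits a 32B model on an 8-GPU host: each GPU holds one TP shard and there is a single data-parallel rank. A rough accounting sketch (the formula is an assumption about how this trainer multiplies its mesh dimensions, and -1 in the old config is read as "infer from world size"):

  # assumed GPU accounting for the trainer mesh:
  # gpus = data_parallel_replicate * data_parallel_shard * tensor_parallel * pipeline_parallel
  #      = 1 * 1 * 8 * 1 = 8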
@@ -90,7 +90,7 @@ replay_buffer:
   batch_size: ${batch_size}
   max_policy_age: ${off_by_n}
   # dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree
-  dp_size: 8
+  dp_size: 1

 # Reference model configuration
 ref_model:
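dp_size: 1 here mirrors data_parallel_shard_degree: 1 above, per the "Must equal trainer DP degree" note. A less drift-prone variant, sketched from the interpolation already hinted at in the commented-out line (assuming the resolver supports cross-section references, as the other ${...} usages in this file suggest):

  replay_buffer:
    dp_size: ${trainer.parallelism.data_parallel_shard_degree}  # stays in sync if the trainer layout changes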
@@ -119,7 +119,7 @@ ref_model:
 services:
   policy:
     procs: ${policy.engine_config.tensor_parallel_size}
-    num_replicas: 1
+    num_replicas: 4
     hosts: 1
     with_gpus: true
   ref_model:
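Since procs is tied to the engine's tensor_parallel_size, each policy replica occupies that many GPUs, and the replica bump multiplies the total. A back-of-envelope sketch (the tensor_parallel_size value is set elsewhere in this file and not shown in this diff, so 8 is an assumption):

  # assumed accounting: policy GPUs = procs * num_replicas
  #                    = tensor_parallel_size * 4   (for instance 8 * 4 = 32 GPUs if the engine runs TP=8)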
apps/mast/qwen3_32b_mast.yaml (20 changes: 10 additions & 10 deletions)

@@ -2,18 +2,18 @@
 # >>> python -m apps.mast.main --config apps/mast/qwen3_1_7b_mast.yaml

 # Global configuration
-group_size: 8
-batch_size: 16
-max_req_tokens: 512
-max_res_tokens: 512
+group_size: 16
+batch_size: 32
+max_req_tokens: 1024
+max_res_tokens: 1024
 model: "Qwen/Qwen3-32B"
 off_by_n: 1 # Off by one by default
 launcher: mast
 job_name: forge-qwen3-32b
 checkpoint_folder: /mnt/wsfuse/teamforge/forge_runs/

 # Main loop configuration
-rollout_threads: ${services.policy.num_replicas} # Recommended to set equal to policy.num_replicas
+rollout_threads: 8

 # Observability configuration
 metric_logging:
@@ -71,8 +71,8 @@ trainer:
     enable: false
   parallelism:
     data_parallel_replicate_degree: 1
-    data_parallel_shard_degree: 8
-    tensor_parallel_degree: 1
+    data_parallel_shard_degree: 1
+    tensor_parallel_degree: 8
     pipeline_parallel_degree: 1
     context_parallel_degree: 1
     expert_parallel_degree: 1
@@ -129,13 +129,13 @@ ref_model:
 services:
   policy:
     procs: ${policy.engine_config.tensor_parallel_size}
-    num_replicas: 2
+    num_replicas: 4
     with_gpus: true
     mesh_name: policy
     hosts: 1
   ref_model:
-    procs: 4
-    num_replicas: 2
+    procs: ${ref_model.parallelism.tensor_parallel_degree}
+    num_replicas: 1
     with_gpus: true
     mesh_name: ref_model
     hosts: 1
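The ref_model edit replaces a hard-coded procs: 4 with an interpolation, so the service launches exactly as many processes as the reference model's tensor-parallel layout requires and the two values cannot drift apart. A sketch with illustrative values only; the actual tensor_parallel_degree is defined elsewhere in this file and is not shown in this diff:

  ref_model:
    parallelism:
      tensor_parallel_degree: 4        # hypothetical value for illustration
  services:
    ref_model:
      procs: ${ref_model.parallelism.tensor_parallel_degree}  # resolves to the value above; follows any TP change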