apps/grpo/qwen3_32b.yaml (9 additions, 9 deletions)

@@ -3,18 +3,18 @@
# NOTE - This has not been tested for correctness yet! All testing so far has been only for infrastructure stability

# Global configuration
-group_size: 2
-local_batch_size: 8 # per-device batch size
-max_req_tokens: 512
-max_res_tokens: 512
+group_size: 16
+local_batch_size: 32 # per-device batch size
+max_req_tokens: 1024
+max_res_tokens: 1024
model: "Qwen/Qwen3-32B"
off_by_n: 1 # Off by one by default

provisioner:
launcher: slurm

# Main loop configuration
-rollout_threads: 1 # Recommended to set equal to policy.num_replicas
+rollout_threads: 32 # setting this to 4x the number of policy replicas seems to work well

# Observability configuration
metric_logging:
@@ -69,8 +69,8 @@ trainer:
enable: false
parallelism:
data_parallel_replicate_degree: 1
-data_parallel_shard_degree: -1
-tensor_parallel_degree: 1
+data_parallel_shard_degree: 1
+tensor_parallel_degree: 8
pipeline_parallel_degree: 1
context_parallel_degree: 1
expert_parallel_degree: 1
@@ -90,7 +90,7 @@ replay_buffer:
batch_size: ${local_batch_size}
max_policy_age: ${off_by_n}
# dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree
-dp_size: 8
+dp_size: 1

# Reference model configuration
ref_model:
@@ -119,7 +119,7 @@ ref_model:
services:
policy:
procs: ${policy.engine_config.tensor_parallel_size}
-num_replicas: 1
+num_replicas: 4
hosts: 1
with_gpus: true
ref_model:
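These values are coupled: the comment on dp_size says it must equal the trainer's data-parallel degree, so it drops from 8 to 1 once the trainer moves to tensor_parallel_degree: 8, and the policy service sizes procs from ${policy.engine_config.tensor_parallel_size}. A minimal consistency check along those lines, assuming the ${...} references are OmegaConf-style interpolations; the script name and the idea of a standalone checker are illustrative, not something this PR adds:

```python
# check_config.py - hypothetical helper, not part of this PR.
# Verifies the invariants called out in the YAML comments.
from omegaconf import OmegaConf

def check(path: str) -> None:
    cfg = OmegaConf.load(path)  # ${...} references resolve lazily on access

    # "dp_size: ...  # Must equal trainer DP degree"
    assert cfg.replay_buffer.dp_size == cfg.trainer.parallelism.data_parallel_shard_degree

    # heuristic from the rollout_threads comment: roughly 4x the policy replicas
    suggested = 4 * cfg.services.policy.num_replicas
    if cfg.rollout_threads != suggested:
        print(f"rollout_threads is {cfg.rollout_threads}; 4x policy replicas would be {suggested}")

if __name__ == "__main__":
    check("apps/grpo/qwen3_32b.yaml")
```

With the numbers in this diff the dp_size assertion holds (1 on both sides), while the heuristic check would flag rollout_threads: 32 against 4 policy replicas, since 32 is 8x rather than 4x the replica count.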
apps/mast/qwen3_14b_mast.yaml (3 additions, 3 deletions)

@@ -3,7 +3,7 @@

# Global configuration
group_size: 8
-batch_size: 16
+local_batch_size: 16 # per-device batch size
max_req_tokens: 512
max_res_tokens: 512
model: "Qwen/Qwen3-14B"
@@ -61,7 +61,7 @@ trainer:
lr_scheduler:
warmup_steps: 1
training:
-local_batch_size: ${batch_size}
+local_batch_size: ${local_batch_size}
seq_len: 2048
max_norm: 1.0
steps: 1000000
@@ -95,7 +95,7 @@ trainer:

# Replay buffer configuration
replay_buffer:
-batch_size: ${batch_size}
+batch_size: ${local_batch_size}
max_policy_age: ${off_by_n}
dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree

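Each of the mast configs receives the same mechanical rename: the global knob becomes local_batch_size to make explicit that it is a per-device value, and both the trainer and the replay buffer interpolate from it, so the two can no longer drift apart. A toy resolution of that pattern, assuming OmegaConf-style interpolation (snippet for illustration only, not code from the repo):

```python
from omegaconf import OmegaConf

# Minimal stand-in for the renamed keys in qwen3_14b_mast.yaml.
cfg = OmegaConf.create("""
local_batch_size: 16  # per-device batch size
trainer:
  training:
    local_batch_size: ${local_batch_size}
replay_buffer:
  batch_size: ${local_batch_size}
""")
OmegaConf.resolve(cfg)  # materialize the ${local_batch_size} references

assert cfg.trainer.training.local_batch_size == 16
assert cfg.replay_buffer.batch_size == 16
```

The 1.7B, 4B and 8B configs further down repeat exactly this change.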
apps/mast/qwen3_1_7b_mast.yaml (3 additions, 3 deletions)

@@ -3,7 +3,7 @@

# Global configuration
group_size: 8
-batch_size: 16
+local_batch_size: 16 # per-device batch size
max_req_tokens: 512
max_res_tokens: 512
model: "Qwen/Qwen3-1.7B"
@@ -61,7 +61,7 @@ trainer:
lr_scheduler:
warmup_steps: 1
training:
-local_batch_size: ${batch_size}
+local_batch_size: ${local_batch_size}
seq_len: 2048
max_norm: 1.0
steps: 1000000
@@ -95,7 +95,7 @@ trainer:

# Replay buffer configuration
replay_buffer:
-batch_size: ${batch_size}
+batch_size: ${local_batch_size}
max_policy_age: ${off_by_n}
dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree

apps/mast/qwen3_32b_mast.yaml (12 additions, 12 deletions)

@@ -2,18 +2,18 @@
# >>> python -m apps.mast.main --config apps/mast/qwen3_1_7b_mast.yaml

# Global configuration
-group_size: 8
-batch_size: 16
-max_req_tokens: 512
-max_res_tokens: 512
+group_size: 16
+local_batch_size: 32 # per-device batch size
+max_req_tokens: 1024
+max_res_tokens: 1024
model: "Qwen/Qwen3-32B"
off_by_n: 1 # Off by one by default
launcher: mast
job_name: forge-qwen3-32b
checkpoint_folder: /mnt/wsfuse/teamforge/forge_runs/

# Main loop configuration
-rollout_threads: ${services.policy.num_replicas} # Recommended to set equal to policy.num_replicas
+rollout_threads: 32 # setting this to 4x the number of policy replicas seems to work well

# Observability configuration
metric_logging:
@@ -61,7 +61,7 @@ trainer:
lr_scheduler:
warmup_steps: 1
training:
-local_batch_size: ${batch_size}
+local_batch_size: ${local_batch_size}
seq_len: 2048
max_norm: 1.0
steps: 1000000
@@ -71,8 +71,8 @@ trainer:
enable: false
parallelism:
data_parallel_replicate_degree: 1
-data_parallel_shard_degree: 8
-tensor_parallel_degree: 1
+data_parallel_shard_degree: 1
+tensor_parallel_degree: 8
pipeline_parallel_degree: 1
context_parallel_degree: 1
expert_parallel_degree: 1
@@ -95,7 +95,7 @@ trainer:

# Replay buffer configuration
replay_buffer:
-batch_size: ${batch_size}
+batch_size: ${local_batch_size}
max_policy_age: ${off_by_n}
dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree

@@ -129,13 +129,13 @@ ref_model:
services:
policy:
procs: ${policy.engine_config.tensor_parallel_size}
-num_replicas: 2
+num_replicas: 4
with_gpus: true
mesh_name: policy
hosts: 1
ref_model:
-procs: 4
-num_replicas: 2
+procs: ${ref_model.parallelism.tensor_parallel_degree}
+num_replicas: 1
with_gpus: true
mesh_name: ref_model
hosts: 1
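For the 32B mast config the trainer flips from 8-way FSDP sharding (data_parallel_shard_degree: 8) to 8-way tensor parallelism, so the number of trainer ranks stays the same; only how the model is split across them changes. A quick sanity check, assuming (as is conventional) that the trainer world size is the product of the replicate, shard, tensor, pipeline and context degrees:

```python
from math import prod

# Parallelism degrees before and after this diff; degrees not touched here stay at 1.
before = {"replicate": 1, "shard": 8, "tensor": 1, "pipeline": 1, "context": 1}
after = {"replicate": 1, "shard": 1, "tensor": 8, "pipeline": 1, "context": 1}

assert prod(before.values()) == prod(after.values()) == 8  # same number of trainer ranks
```

Because replay_buffer.dp_size in the mast configs is interpolated from ${trainer.parallelism.data_parallel_shard_degree}, it tracks the 8 to 1 change automatically; the grpo config above hard-codes it and had to be edited by hand. The reference model's procs is likewise now tied to ${ref_model.parallelism.tensor_parallel_degree} instead of a literal 4.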
apps/mast/qwen3_4b_mast.yaml (3 additions, 3 deletions)

@@ -3,7 +3,7 @@

# Global configuration
group_size: 8
-batch_size: 16
+local_batch_size: 16 # per-device batch size
max_req_tokens: 512
max_res_tokens: 512
model: "Qwen/Qwen3-4B"
@@ -61,7 +61,7 @@ trainer:
lr_scheduler:
warmup_steps: 1
training:
-local_batch_size: ${batch_size}
+local_batch_size: ${local_batch_size}
seq_len: 2048
max_norm: 1.0
steps: 1000000
@@ -95,7 +95,7 @@ trainer:

# Replay buffer configuration
replay_buffer:
-batch_size: ${batch_size}
+batch_size: ${local_batch_size}
max_policy_age: ${off_by_n}
dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree

apps/mast/qwen3_8b_mast.yaml (3 additions, 3 deletions)

@@ -3,7 +3,7 @@

# Global configuration
group_size: 8
-batch_size: 16
+local_batch_size: 16 # per-device batch size
max_req_tokens: 512
max_res_tokens: 512
model: "Qwen/Qwen3-8B"
@@ -61,7 +61,7 @@ trainer:
lr_scheduler:
warmup_steps: 1
training:
-local_batch_size: ${batch_size}
+local_batch_size: ${local_batch_size}
seq_len: 2048
max_norm: 1.0
steps: 1000000
@@ -95,7 +95,7 @@ trainer:

# Replay buffer configuration
replay_buffer:
-batch_size: ${batch_size}
+batch_size: ${local_batch_size}
max_policy_age: ${off_by_n}
dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree
