diff --git a/.meta/mast/qwen3_14b_mast.yaml b/.meta/mast/qwen3_14b_mast.yaml
index 9560db4e1..d5cfae9e5 100644
--- a/.meta/mast/qwen3_14b_mast.yaml
+++ b/.meta/mast/qwen3_14b_mast.yaml
@@ -61,7 +61,7 @@ trainer:
     warmup_steps: 1
   training:
     local_batch_size: ${local_batch_size}
-    seq_len: 2048
+    seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
     max_norm: 1.0
     steps: 1000000
     dtype: bfloat16
@@ -106,6 +106,7 @@ ref_model:
     flavor: 14B
     hf_assets_path: /mnt/wsfuse/teamforge/hf/qwen3_14b
   training:
+    seq_len: ${trainer.training.seq_len}
     dtype: bfloat16
     gc_freq: 1
   compile:
diff --git a/.meta/mast/qwen3_1_7b_mast.yaml b/.meta/mast/qwen3_1_7b_mast.yaml
index 604fc4f4e..f859b77be 100644
--- a/.meta/mast/qwen3_1_7b_mast.yaml
+++ b/.meta/mast/qwen3_1_7b_mast.yaml
@@ -62,7 +62,7 @@ trainer:
     warmup_steps: 1
   training:
     local_batch_size: ${local_batch_size}
-    seq_len: 2048
+    seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
     max_norm: 1.0
     steps: 1000000
     dtype: bfloat16
@@ -108,6 +108,7 @@ ref_model:
     hf_assets_path: /mnt/wsfuse/teamforge/hf/qwen3_1.7b
     # hf_assets_path: hf://${model}
   training:
+    seq_len: ${trainer.training.seq_len}
     dtype: bfloat16
     gc_freq: 1
   compile:
diff --git a/.meta/mast/qwen3_32b_mast.yaml b/.meta/mast/qwen3_32b_mast.yaml
index b9079a2c2..a606f2f8e 100644
--- a/.meta/mast/qwen3_32b_mast.yaml
+++ b/.meta/mast/qwen3_32b_mast.yaml
@@ -61,7 +61,7 @@ trainer:
     warmup_steps: 1
   training:
     local_batch_size: ${local_batch_size}
-    seq_len: 2048
+    seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
     max_norm: 1.0
     steps: 1000000
     dtype: bfloat16
@@ -106,6 +106,7 @@ ref_model:
     flavor: 32B
     hf_assets_path: /mnt/wsfuse/teamforge/hf/qwen3_32b
   training:
+    seq_len: ${trainer.training.seq_len}
     dtype: bfloat16
     gc_freq: 1
   compile:
diff --git a/.meta/mast/qwen3_4b_mast.yaml b/.meta/mast/qwen3_4b_mast.yaml
index 5e7442c12..15bb08a7f 100644
--- a/.meta/mast/qwen3_4b_mast.yaml
+++ b/.meta/mast/qwen3_4b_mast.yaml
@@ -62,7 +62,7 @@ trainer:
     warmup_steps: 1
   training:
     local_batch_size: ${local_batch_size}
-    seq_len: 2048
+    seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
     max_norm: 1.0
     steps: 1000000
     dtype: bfloat16
@@ -108,6 +108,7 @@ ref_model:
     hf_assets_path: /mnt/wsfuse/teamforge/hf/qwen3_4b
     # hf_assets_path: hf://${model}
   training:
+    seq_len: ${trainer.training.seq_len}
     dtype: bfloat16
     gc_freq: 1
   compile:
diff --git a/.meta/mast/qwen3_8b_mast.yaml b/.meta/mast/qwen3_8b_mast.yaml
index ec90db0ff..2b2d8b2a8 100644
--- a/.meta/mast/qwen3_8b_mast.yaml
+++ b/.meta/mast/qwen3_8b_mast.yaml
@@ -61,7 +61,7 @@ trainer:
     warmup_steps: 1
   training:
     local_batch_size: ${local_batch_size}
-    seq_len: 2048
+    seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
     max_norm: 1.0
     steps: 1000000
     dtype: bfloat16
@@ -106,6 +106,7 @@ ref_model:
     flavor: 8B
     hf_assets_path: /mnt/wsfuse/teamforge/hf/qwen3_8b
   training:
+    seq_len: ${trainer.training.seq_len}
     dtype: bfloat16
     gc_freq: 1
   compile:
diff --git a/apps/grpo/qwen3_1_7b.yaml b/apps/grpo/qwen3_1_7b.yaml
index 6bb2ebab3..c6fc1613b 100644
--- a/apps/grpo/qwen3_1_7b.yaml
+++ b/apps/grpo/qwen3_1_7b.yaml
@@ -4,8 +4,8 @@
 # Global configuration
 group_size: 8
 local_batch_size: 16 # per-device batch size
-max_req_tokens: 512
-max_res_tokens: 512
+max_req_tokens: 1024
+max_res_tokens: 1024
 model: "Qwen/Qwen3-1.7B"
 off_by_n: 1 # Off by one by default

@@ -57,7 +57,7 @@ trainer:
     warmup_steps: 1
   training:
     local_batch_size: ${local_batch_size}
-    seq_len: 2048
+    seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
     max_norm: 1.0
     steps: 1000000
     dtype: bfloat16
diff --git a/apps/grpo/qwen3_32b.yaml b/apps/grpo/qwen3_32b.yaml
index 67d9e3a77..639f6669e 100644
--- a/apps/grpo/qwen3_32b.yaml
+++ b/apps/grpo/qwen3_32b.yaml
@@ -60,7 +60,7 @@ trainer:
     warmup_steps: 1
   training:
     local_batch_size: ${local_batch_size}
-    seq_len: 2048
+    seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
     max_norm: 1.0
     steps: 1000000
     dtype: bfloat16
diff --git a/apps/grpo/qwen3_8b.yaml b/apps/grpo/qwen3_8b.yaml
index 683aa1503..22a4a3961 100644
--- a/apps/grpo/qwen3_8b.yaml
+++ b/apps/grpo/qwen3_8b.yaml
@@ -4,8 +4,8 @@
 # Global configuration
 group_size: 8
 local_batch_size: 16 # per-device batch size
-max_req_tokens: 512
-max_res_tokens: 512
+max_req_tokens: 1024
+max_res_tokens: 1024
 model: "Qwen/Qwen3-8B"
 off_by_n: 1 # Off by one by default

@@ -53,7 +53,7 @@ trainer:
     warmup_steps: 1
   training:
     local_batch_size: ${local_batch_size}
-    seq_len: 2048
+    seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
     max_norm: 1.0
     steps: 1000000
     dtype: bfloat16
diff --git a/src/forge/util/config.py b/src/forge/util/config.py
index 2dd171ae3..0315ca525 100644
--- a/src/forge/util/config.py
+++ b/src/forge/util/config.py
@@ -15,6 +15,9 @@

 from omegaconf import DictConfig, OmegaConf

+# Add support for summing numbers, e.g. ${sum:${max_req_tokens},${max_res_tokens}}
+OmegaConf.register_new_resolver("sum", lambda *args: sum(args), replace=True)
+

 def _has_component(node: Any) -> bool:
     """Check if a node has a _component_ field."""
diff --git a/tests/integration_tests/fixtures/qwen3_1_7b_no_tp.yaml b/tests/integration_tests/fixtures/qwen3_1_7b_no_tp.yaml
index 64588b6d4..5e0bcf17f 100644
--- a/tests/integration_tests/fixtures/qwen3_1_7b_no_tp.yaml
+++ b/tests/integration_tests/fixtures/qwen3_1_7b_no_tp.yaml
@@ -34,7 +34,7 @@ trainer:
     warmup_steps: 1
   training:
     local_batch_size: ${batch_size}
-    seq_len: 2048
+    seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
     max_norm: 1.0
     steps: 1000000
     dtype: bfloat16
diff --git a/tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml b/tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml
index cf3d7dc80..1f05daab4 100644
--- a/tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml
+++ b/tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml
@@ -36,7 +36,7 @@ trainer:
     warmup_steps: 1
   training:
     local_batch_size: ${batch_size}
-    seq_len: 2048
+    seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
     max_norm: 1.0
     steps: 1000000
     dtype: bfloat16
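
For reference, the resolver this patch registers can be exercised standalone. A minimal sketch, assuming only the omegaconf package; the key names mirror the configs above, but the snippet is illustrative and not part of the patch:

# Standalone sketch (illustrative, not part of the patch) of the "sum" resolver
# registered in src/forge/util/config.py above.
from omegaconf import OmegaConf

# Same registration as the patch: sum all arguments passed to ${sum:...}.
OmegaConf.register_new_resolver("sum", lambda *args: sum(args), replace=True)

cfg = OmegaConf.create(
    """
max_req_tokens: 1024
max_res_tokens: 1024
trainer:
  training:
    seq_len: ${sum:${max_req_tokens},${max_res_tokens}}
ref_model:
  training:
    seq_len: ${trainer.training.seq_len}
"""
)

# Interpolations resolve on access: seq_len == max_req_tokens + max_res_tokens.
assert cfg.trainer.training.seq_len == 2048
# The ref model tracks the trainer's value via node interpolation.
assert cfg.ref_model.training.seq_len == 2048

Because ref_model.training.seq_len now interpolates trainer.training.seq_len, raising max_req_tokens / max_res_tokens (as the GRPO 1.7B and 8B configs do here, 512 -> 1024) propagates one consistent sequence length to both the trainer and the reference model, instead of relying on a hardcoded seq_len: 2048 staying in sync by hand.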