Skip to content

Commit 374a8fa

Browse files
felipemello1 (Felipe Mello) and a co-author authored
[FIX] Remove hardcoded seq len (meta-pytorch#497)
Co-authored-by: Felipe Mello <[email protected]>
1 parent a73756c commit 374a8fa

File tree

11 files changed

+22
-14
lines changed

11 files changed

+22
-14
lines changed

.meta/mast/qwen3_14b_mast.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ trainer:
6161
warmup_steps: 1
6262
training:
6363
local_batch_size: ${local_batch_size}
64-
seq_len: 2048
64+
seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
6565
max_norm: 1.0
6666
steps: 1000000
6767
dtype: bfloat16
@@ -106,6 +106,7 @@ ref_model:
106106
flavor: 14B
107107
hf_assets_path: /mnt/wsfuse/teamforge/hf/qwen3_14b
108108
training:
109+
seq_len: ${trainer.training.seq_len}
109110
dtype: bfloat16
110111
gc_freq: 1
111112
compile:

.meta/mast/qwen3_1_7b_mast.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ trainer:
6262
warmup_steps: 1
6363
training:
6464
local_batch_size: ${local_batch_size}
65-
seq_len: 2048
65+
seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
6666
max_norm: 1.0
6767
steps: 1000000
6868
dtype: bfloat16
@@ -108,6 +108,7 @@ ref_model:
108108
hf_assets_path: /mnt/wsfuse/teamforge/hf/qwen3_1.7b
109109
# hf_assets_path: hf://${model}
110110
training:
111+
seq_len: ${trainer.training.seq_len}
111112
dtype: bfloat16
112113
gc_freq: 1
113114
compile:

.meta/mast/qwen3_32b_mast.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ trainer:
6161
warmup_steps: 1
6262
training:
6363
local_batch_size: ${local_batch_size}
64-
seq_len: 2048
64+
seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
6565
max_norm: 1.0
6666
steps: 1000000
6767
dtype: bfloat16
@@ -106,6 +106,7 @@ ref_model:
106106
flavor: 32B
107107
hf_assets_path: /mnt/wsfuse/teamforge/hf/qwen3_32b
108108
training:
109+
seq_len: ${trainer.training.seq_len}
109110
dtype: bfloat16
110111
gc_freq: 1
111112
compile:

.meta/mast/qwen3_4b_mast.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ trainer:
6262
warmup_steps: 1
6363
training:
6464
local_batch_size: ${local_batch_size}
65-
seq_len: 2048
65+
seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
6666
max_norm: 1.0
6767
steps: 1000000
6868
dtype: bfloat16
@@ -108,6 +108,7 @@ ref_model:
108108
hf_assets_path: /mnt/wsfuse/teamforge/hf/qwen3_4b
109109
# hf_assets_path: hf://${model}
110110
training:
111+
seq_len: ${trainer.training.seq_len}
111112
dtype: bfloat16
112113
gc_freq: 1
113114
compile:

.meta/mast/qwen3_8b_mast.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ trainer:
6161
warmup_steps: 1
6262
training:
6363
local_batch_size: ${local_batch_size}
64-
seq_len: 2048
64+
seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
6565
max_norm: 1.0
6666
steps: 1000000
6767
dtype: bfloat16
@@ -106,6 +106,7 @@ ref_model:
106106
flavor: 8B
107107
hf_assets_path: /mnt/wsfuse/teamforge/hf/qwen3_8b
108108
training:
109+
seq_len: ${trainer.training.seq_len}
109110
dtype: bfloat16
110111
gc_freq: 1
111112
compile:

apps/grpo/qwen3_1_7b.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
# Global configuration
55
group_size: 8
66
local_batch_size: 16 # per-device batch size
7-
max_req_tokens: 512
8-
max_res_tokens: 512
7+
max_req_tokens: 1024
8+
max_res_tokens: 1024
99
model: "Qwen/Qwen3-1.7B"
1010
off_by_n: 1 # Off by one by default
1111

@@ -57,7 +57,7 @@ trainer:
5757
warmup_steps: 1
5858
training:
5959
local_batch_size: ${local_batch_size}
60-
seq_len: 2048
60+
seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
6161
max_norm: 1.0
6262
steps: 1000000
6363
dtype: bfloat16

apps/grpo/qwen3_32b.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ trainer:
6060
warmup_steps: 1
6161
training:
6262
local_batch_size: ${local_batch_size}
63-
seq_len: 2048
63+
seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
6464
max_norm: 1.0
6565
steps: 1000000
6666
dtype: bfloat16

apps/grpo/qwen3_8b.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
# Global configuration
55
group_size: 8
66
local_batch_size: 16 # per-device batch size
7-
max_req_tokens: 512
8-
max_res_tokens: 512
7+
max_req_tokens: 1024
8+
max_res_tokens: 1024
99
model: "Qwen/Qwen3-8B"
1010
off_by_n: 1 # Off by one by default
1111

@@ -53,7 +53,7 @@ trainer:
5353
warmup_steps: 1
5454
training:
5555
local_batch_size: ${local_batch_size}
56-
seq_len: 2048
56+
seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
5757
max_norm: 1.0
5858
steps: 1000000
5959
dtype: bfloat16

src/forge/util/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515

1616
from omegaconf import DictConfig, OmegaConf
1717

18+
# Add support for summing lists of numbers, e.g. ${sum:${max_req_tokens},${max_res_tokens}}
19+
OmegaConf.register_new_resolver("sum", lambda *args: sum(args), replace=True)
20+
1821

1922
def _has_component(node: Any) -> bool:
2023
"""Check if a node has a _component_ field."""

tests/integration_tests/fixtures/qwen3_1_7b_no_tp.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ trainer:
3434
warmup_steps: 1
3535
training:
3636
local_batch_size: ${batch_size}
37-
seq_len: 2048
37+
seq_len: ${sum:${max_req_tokens},${max_res_tokens}} # seq_len >= max_req_tokens + max_res_tokens
3838
max_norm: 1.0
3939
steps: 1000000
4040
dtype: bfloat16

0 commit comments

Comments (0)