Skip to content

Commit 3186797

Browse files
authored
Update torchtitan for proper bf16 & new quant APIs (#281)
1 parent 1f74869 commit 3186797

File tree

4 files changed

+5
-5
lines changed

4 files changed

+5
-5
lines changed

apps/grpo/qwen3_8b.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
# Global configuration
55
group_size: 8
66
batch_size: 16
7-
max_req_tokens: 468
8-
max_res_tokens: 468
7+
max_req_tokens: 512
8+
max_res_tokens: 512
99
model: "Qwen/Qwen3-8B"
1010
off_by_n: 1 # Off by one by default
1111

757 Bytes
Binary file not shown.

scripts/build_wheels.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ NC='\033[0m'
1818
PYTORCH_VERSION="2.9.0.dev20250905"
1919
VLLM_BRANCH="v0.10.0"
2020
MONARCH_COMMIT="9c41b5c16edadeab7cfb8521ba7efe68a1e2bc87"
21-
TORCHTITAN_COMMIT="a3104201ba3a0fa19e9c3cc5ba748b0398551410"
21+
TORCHTITAN_COMMIT="9f3fe08635356b829e6bf41883760679a8207697"
2222
TORCHSTORE_COMMIT="0052f6d8b686b9cff0cf4ce203a836c4b5d5ac94"
2323
BUILD_DIR="$HOME/forge-build"
2424
WHEEL_DIR="$(pwd)/assets/wheels"

src/forge/actors/trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
Checkpoint,
2828
Comm,
2929
Compile,
30-
Float8Dense,
30+
Float8Linear,
3131
LRScheduler,
3232
Model,
3333
Optimizer,
@@ -104,7 +104,7 @@ class RLTrainer(ForgeActor):
104104
)
105105
use_vllm_builtin_load: bool = True
106106
compile: Compile = field(default_factory=Compile)
107-
float8: Float8Dense = field(default_factory=Float8Dense)
107+
float8: Float8Linear = field(default_factory=Float8Linear)
108108
comm: Comm = field(default_factory=Comm)
109109
loss: Callable = lambda logits, **targets: logits
110110
state_dict_key: str = "model_state_dict"

0 commit comments

Comments
 (0)