Skip to content

Commit 769e3e8

Browse files
authored
Update RLTrainer init variables to match TorchTitan JobConfig (#293)
1 parent 1211ecd commit 769e3e8

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

src/forge/actors/trainer.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,13 @@
2727
Checkpoint,
2828
Comm,
2929
Compile,
30-
Float8Linear,
30+
Job,
3131
LRScheduler,
32+
MemoryEstimation,
3233
Model,
3334
Optimizer,
3435
Parallelism,
36+
Quantize,
3537
Training,
3638
)
3739
from torchtitan.experiments.forge.engine import ForgeEngine
@@ -93,6 +95,7 @@ def cleanup_old_weight_versions(
9395

9496
@dataclass
9597
class RLTrainer(ForgeActor):
98+
job: Job = field(default_factory=Job)
9699
model: Model = field(default_factory=Model)
97100
optimizer: Optimizer = field(default_factory=Optimizer)
98101
lr_scheduler: LRScheduler = field(default_factory=LRScheduler)
@@ -102,15 +105,17 @@ class RLTrainer(ForgeActor):
102105
activation_checkpoint: ActivationCheckpoint = field(
103106
default_factory=ActivationCheckpoint
104107
)
105-
use_vllm_builtin_load: bool = True
106108
compile: Compile = field(default_factory=Compile)
107-
float8: Float8Linear = field(default_factory=Float8Linear)
109+
quantize: Quantize = field(default_factory=Quantize)
108110
comm: Comm = field(default_factory=Comm)
111+
memory_estimation: MemoryEstimation = field(default_factory=MemoryEstimation)
112+
# Non JobConfig-related fields
109113
loss: Callable = lambda logits, **targets: logits
110114
state_dict_key: str = "model_state_dict"
111115
use_dcp: bool = True
112116
dcp_path: str = "forge_dcp_tmp"
113117
vllm_tp_DEPRECATED: int = 1 # noqa: N815
118+
use_vllm_builtin_load: bool = True
114119

115120
def __post_init__(self):
116121
"""Initializes config types and env variables.

0 commit comments

Comments
 (0)